Revert "[xpu][feature] Integrate OneDNN SDPA training forward/backward into XPU OVERRIDEABLE Backend (#162454)"

This reverts commit fd68d409ad.

Reverted https://github.com/pytorch/pytorch/pull/162454 on behalf of https://github.com/atalman due to internal build failure ([comment](https://github.com/pytorch/pytorch/pull/162454#issuecomment-3475009089))
PyTorch MergeBot 2025-10-31 21:58:52 +00:00
parent 9970fb97ff
commit 2699f5410b
13 changed files with 49 additions and 236 deletions


@@ -40,37 +40,14 @@ bool check_head_dim_size_xpu(sdp::sdp_params const& params, bool debug) {
   return true;
 }
 
-bool input_require_grad(
-    const at::Tensor& query,
-    const at::Tensor& key,
-    const at::Tensor& value,
-    const std::optional<at::Tensor>& attn_mask) {
-  return at::GradMode::is_enabled() &&
-      (query.requires_grad() || key.requires_grad() || value.requires_grad() ||
-       (attn_mask.has_value() && attn_mask.value().requires_grad()));
-}
-
-bool check_grad(sdp::sdp_params const& params, bool debug) {
-  if (!input_require_grad(
-          params.query, params.key, params.value, params.attn_mask))
-    return true;
-  auto q_num_heads = params.query.sym_size(-3);
-  auto k_num_heads = params.key.sym_size(-3);
-  auto v_num_heads = params.value.sym_size(-3);
-  bool is_gqa = q_num_heads != k_num_heads || q_num_heads != v_num_heads;
-  if (debug && is_gqa)
-    TORCH_WARN(
-        "scale_dot_product_attention with gqa is not supported for gradient computation on xpu.");
-  bool attn_mask_needs_grad =
-      params.attn_mask.has_value() && params.attn_mask.value().requires_grad();
-  if (debug && attn_mask_needs_grad) {
-    TORCH_WARN(
-        "scale_dot_product_attention on xpu is not supported when attn_mask.requires_grad() == True.");
-  }
-  return !is_gqa && !attn_mask_needs_grad;
+bool check_no_grad(sdp::sdp_params const& params, bool debug) {
+  const bool any_inputs_require_grad = params.query.requires_grad() ||
+      params.key.requires_grad() || params.value.requires_grad();
+  const bool gradmode_enabled = at::GradMode::is_enabled();
+  if (debug && any_inputs_require_grad && gradmode_enabled) {
+    TORCH_WARN("Backward or grad to be supported.");
+  }
+  return !any_inputs_require_grad || !gradmode_enabled;
 }
 
 bool can_use_overrideable_attention(sdp::sdp_params const& params, bool debug) {
@@ -88,7 +65,7 @@ bool can_use_overrideable_attention(sdp::sdp_params const& params, bool debug) {
       sdp::check_nonzero_sequence_lengths_dense,
       sdp::check_last_dim_stride_equals_1_dense<false /*ignore_singleton_dim*/>,
       check_head_dim_size_xpu,
-      check_grad);
+      check_no_grad);
   for (auto& constraint : constraints) {
     if (!constraint(params, debug)) {
       return false;
@@ -248,11 +225,10 @@ _scaled_dot_product_fused_attention_overrideable_xpu(
     double dropout_p,
     bool is_causal,
     bool return_debug_mask,
-    std::optional<double> scale,
-    bool compute_logsumexp) {
+    std::optional<double> scale) {
   TORCH_INTERNAL_ASSERT(
       query.dim() == 4 && key.dim() == 4 && value.dim() == 4,
-      "scaled_dot_product_fused_attention_overrideable_xpu: Accept only 4 dims inputs shape of {B, H, T, K}");
+      "scaled_dot_product_fused_attention_overrideable_xpu: Accept only 4 dims inputs shape of {(B), H, T, K}");
   TORCH_INTERNAL_ASSERT(
       (key.size(0) == value.size(0)) && (key.size(1) == value.size(1)) &&
           (key.size(2) == value.size(2)),
@@ -269,9 +245,6 @@ _scaled_dot_product_fused_attention_overrideable_xpu(
   TORCH_INTERNAL_ASSERT(
       !(attn_bias.has_value() && is_causal),
       "scaled_dot_product_fused_attention_overrideable_xpu: attn_bias cannot present with is_causal");
-  TORCH_INTERNAL_ASSERT(
-      !(attn_bias.has_value() && attn_bias.value().requires_grad()),
-      "scaled_dot_product_fused_attention_overrideable_xpu: attn_bias cannot have requires_grad=True");
 
   const int64_t batch_size = query.size(0);
   const int64_t num_head_q = query.size(1);
@@ -281,14 +254,11 @@ _scaled_dot_product_fused_attention_overrideable_xpu(
   const int64_t seq_len_q = query.size(2);
   const int64_t seq_len_kv = key.size(2);
 
-  at::Tensor attention;
-  std::vector<int64_t> attention_shape = {
+  at::Tensor output;
+  std::vector<int64_t> output_shape = {
       batch_size, num_head_q, seq_len_q, head_dim_v};
-  alloc_with_matching_layout(query, attention, attention_shape);
-
-  auto opts = query.options();
-  at::Tensor logsumexp =
-      at::empty({batch_size, num_head_q, seq_len_q}, opts.dtype(at::kFloat));
+  alloc_with_matching_layout(query, output, output_shape);
+  at::Tensor logsumexp, debug_attn_mask; // not supported
 
   at::native::onednn::sdpa(
       batch_size,
@@ -304,15 +274,15 @@ _scaled_dot_product_fused_attention_overrideable_xpu(
       attn_bias,
       is_causal,
       scale.has_value() ? scale.value() : (1.0 / std::sqrt(head_dim_qk)),
-      attention,
-      compute_logsumexp,
+      output,
+      false,
       logsumexp);
 
   // rng not used
   auto philox_seed = at::empty({}, at::dtype(at::kLong));
   auto philox_offset = at::empty({}, at::dtype(at::kLong));
 
   return std::make_tuple(
-      attention,
+      output,
       logsumexp,
       /* cum_seq_q */ at::Tensor(),
       /* cum_seq_k */ at::Tensor(),
@@ -320,106 +290,7 @@ _scaled_dot_product_fused_attention_overrideable_xpu(
       seq_len_kv,
       philox_seed,
       philox_offset,
-      /*debug_attn_mask */ at::Tensor());
+      debug_attn_mask);
 }
 
-std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
-_scaled_dot_product_fused_attention_overrideable_backward_xpu(
-    const at::Tensor& grad_out,
-    const at::Tensor& query,
-    const at::Tensor& key,
-    const at::Tensor& value,
-    const at::Tensor& attn_bias,
-    std::array<bool, 4> grad_input_mask,
-    const at::Tensor& out,
-    const at::Tensor& logsumexp,
-    const at::Tensor& cum_seq_q,
-    const at::Tensor& cum_seq_k,
-    int64_t max_q,
-    int64_t max_k,
-    double dropout_p,
-    bool is_causal,
-    const at::Tensor& philox_seed,
-    const at::Tensor& philox_offset,
-    std::optional<double> scale) {
-  TORCH_INTERNAL_ASSERT(
-      grad_out.dim() == 4 && out.dim() == 4 &&
-          grad_out.size(0) == out.size(0) && grad_out.size(1) == out.size(1) &&
-          grad_out.size(2) == out.size(2) && grad_out.size(3) == out.size(3),
-      "scaled_dot_product_fused_attention_overrideable_backward_xpu: grad_out and out should have the same shape of {B, H, T, K}");
-  TORCH_INTERNAL_ASSERT(
-      query.dim() == 4 && key.dim() == 4 && value.dim() == 4,
-      "scaled_dot_product_fused_attention_overrideable_backward_xpu: Accept only 4 dims inputs shape of {B, H, T, K}");
-  TORCH_INTERNAL_ASSERT(
-      (key.size(0) == value.size(0)) && (key.size(1) == value.size(1)) &&
-          (key.size(2) == value.size(2)),
-      "scaled_dot_product_fused_attention_overrideable_backward_xpu: K/V should have the same batch / seq / num_head");
-  TORCH_INTERNAL_ASSERT(
-      query.size(0) == grad_out.size(0) && query.size(1) == grad_out.size(1) &&
-          query.size(2) == grad_out.size(2),
-      "scaled_dot_product_fused_attention_overrideable_backward_xpu: Q should have the same batch / num_head / seq_len as grad_out");
-  TORCH_INTERNAL_ASSERT(
-      query.size(3) == key.size(3),
-      "scaled_dot_product_fused_attention_overrideable_backward_xpu: Q/K should have the same head_dim");
-  TORCH_INTERNAL_ASSERT(
-      value.size(3) == grad_out.size(3),
-      "scaled_dot_product_fused_attention_overrideable_backward_xpu: V should have the same head_dim as grad_out");
-  TORCH_INTERNAL_ASSERT(
-      query.size(1) == key.size(1),
-      "scaled_dot_product_fused_attention_overrideable_backward_xpu: number of heads in K/V must equal to number of heads in Q");
-  TORCH_INTERNAL_ASSERT(
-      dropout_p == 0.0,
-      "scaled_dot_product_fused_attention_overrideable_backward_xpu: Currently do not support dropout > 0");
-  TORCH_INTERNAL_ASSERT(
-      logsumexp.dim() == 3 && logsumexp.size(0) == query.size(0) &&
-          logsumexp.size(1) == query.size(1) &&
-          logsumexp.size(2) == query.size(2) &&
-      "scaled_dot_product_fused_attention_overrideable_backward_xpu: logsumexp should have the shape of {B, H, T}");
-
-  std::optional<Tensor> attn_bias_opt;
-  if (attn_bias.defined()) {
-    attn_bias_opt = attn_bias;
-  }
-
-  const int64_t batch_size = query.size(0);
-  const int64_t num_head_q = query.size(1);
-  const int64_t num_head_kv = key.size(1);
-  const int64_t seq_len_q = query.size(2);
-  const int64_t seq_len_kv = key.size(2);
-  const int64_t head_dim_qk = query.size(3);
-  const int64_t head_dim_v = value.size(3);
-
-  auto grad_q = at::empty_like(query);
-  auto grad_k = at::empty_like(key);
-  auto grad_v = at::empty_like(value);
-  auto grad_attn_bias = attn_bias_opt.has_value()
-      ? at::empty_like(attn_bias_opt.value())
-      : at::Tensor();
-
-  at::native::onednn::sdpa_backward(
-      batch_size,
-      num_head_q,
-      num_head_kv,
-      seq_len_q,
-      seq_len_kv,
-      head_dim_qk,
-      head_dim_v,
-      grad_out,
-      query,
-      key,
-      value,
-      out,
-      logsumexp,
-      attn_bias_opt,
-      is_causal,
-      scale.has_value() ? scale.value() : (1.0 / std::sqrt(query.size(3))),
-      grad_q,
-      grad_k,
-      grad_v);
-  return std::make_tuple(
-      std::move(grad_q),
-      std::move(grad_k),
-      std::move(grad_v),
-      std::move(grad_attn_bias));
-}
-
 REGISTER_XPU_DISPATCH(_fused_sdp_choice_stub, &_fused_sdp_choice_xpu);

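With the training integration reverted, the restored check_no_grad() constraint above means the XPU OVERRIDEABLE backend is only selected when no input requires grad. A minimal, hedged usage sketch of that inference-only path (assumes a PyTorch build with XPU support and an available "xpu" device; shapes are illustrative):

import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

# (B, H, T, K) inputs; nothing requires grad, so check_no_grad() passes and the
# oneDNN-backed overrideable kernel stays eligible for selection.
q, k, v = (torch.randn(2, 8, 128, 64, device="xpu", dtype=torch.half) for _ in range(3))

with torch.no_grad(), sdpa_kernel(backends=[SDPBackend.OVERRIDEABLE]):
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)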

@@ -15095,7 +15095,7 @@
     CPU: _scaled_dot_product_flash_attention_cpu
   tags: nondeterministic_seeded
 
-- func: _scaled_dot_product_fused_attention_overrideable(Tensor query, Tensor key, Tensor value, Tensor? attn_bias=None, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None, bool compute_log_sumexp=True) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
+- func: _scaled_dot_product_fused_attention_overrideable(Tensor query, Tensor key, Tensor value, Tensor? attn_bias=None, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
   dispatch:
     CompositeExplicitAutograd: _scaled_dot_product_fused_attention_overrideable
     XPU: _scaled_dot_product_fused_attention_overrideable_xpu
@@ -15119,7 +15119,6 @@
   variants: function
   dispatch:
     CompositeExplicitAutograd: _scaled_dot_product_fused_attention_overrideable_backward
-    XPU: _scaled_dot_product_fused_attention_overrideable_backward_xpu
 
 - func: _scaled_dot_product_efficient_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_bias, bool compute_log_sumexp, float dropout_p=0.0, bool is_causal=False, *, float? scale=None) -> (Tensor output, Tensor log_sumexp, Tensor philox_seed, Tensor philox_offset)
   dispatch:

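After the revert, the overrideable schema above again omits compute_log_sumexp, and the XPU backward dispatch entry is gone. A hedged sketch of a direct call against the restored schema (assumes a backend that actually implements the op, e.g. XPU, and 4-D {B, H, T, K} inputs; it mirrors the call used in the test change further below):

import torch

q = torch.randn(2, 8, 128, 64, device="xpu", dtype=torch.half)
k = torch.randn(2, 8, 128, 64, device="xpu", dtype=torch.half)
v = torch.randn(2, 8, 128, 64, device="xpu", dtype=torch.half)

# Returns a 9-tuple: (output, logsumexp, cum_seq_q, cum_seq_k, max_q, max_k,
# philox_seed, philox_offset, debug_attn_mask); no compute_log_sumexp argument.
out, logsumexp, *rest = torch.ops.aten._scaled_dot_product_fused_attention_overrideable(
    q, k, v, attn_bias=None, dropout_p=0.0, is_causal=False,
    return_debug_mask=False, scale=None)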

@@ -768,11 +768,8 @@ Tensor scaled_dot_product_attention(
       return std::get<0>(out_and_lse);
     }
     case SDPBackend::overrideable: {
-      bool compute_logsumexp = should_compute_logsumexp(query_, key, value);
-      compute_logsumexp = compute_logsumexp ||
-          (at::GradMode::is_enabled() && attn_mask.has_value() && attn_mask.value().requires_grad());
       auto out_lse_softmax = at::_scaled_dot_product_fused_attention_overrideable(
-          query_, key, value, attn_mask, dropout_p, is_causal, false /*return_debug_mask*/, scale, compute_logsumexp);
+          query_, key, value, attn_mask, dropout_p, is_causal, false /*return_debug_mask*/, scale);
       return std::get<0>(out_lse_softmax);
     }
     case SDPBackend::math: {
@@ -1018,8 +1015,7 @@ _scaled_dot_product_fused_attention_overrideable(
     double dropout_p,
     bool is_causal,
     bool return_debug_mask,
-    std::optional<double> scale,
-    bool compute_logsumexp) {
+    std::optional<double> scale) {
   TORCH_CHECK_NOT_IMPLEMENTED(false, "_scaled_dot_product_fused_attention_overrideable not implemented. This is an operator for privateuse1 backends, please use TORCH_LIBRARY_IMPL to override this function ");
 }


@@ -58,8 +58,7 @@ wrapper__scaled_dot_product_fused_attention_overrideable(
     double dropout_p,
     bool is_causal,
     bool return_debug_mask,
-    std::optional<double> scale,
-    bool compute_log_sumexp) {
+    std::optional<double> scale) {
   return at::native::openreg::_scaled_dot_product_fused_attention_overrideable(
       query,
       key,
@@ -68,8 +67,7 @@ wrapper__scaled_dot_product_fused_attention_overrideable(
       dropout_p,
       is_causal,
       return_debug_mask,
-      scale,
-      compute_log_sumexp);
+      scale);
 }
 
 std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>


@@ -47,8 +47,7 @@ _scaled_dot_product_fused_attention_overrideable(
     double dropout_p,
     bool is_causal,
     bool return_debug_mask,
-    std::optional<double> scale,
-    bool compute_log_sumexp) {
+    std::optional<double> scale) {
   const int64_t batch_size = query.size(0);
   const int64_t num_heads = query.size(1);
   const int64_t head_dim_v = value.size(3);


@@ -39,8 +39,7 @@ _scaled_dot_product_fused_attention_overrideable(
     double dropout_p,
     bool is_causal,
     bool return_debug_mask,
-    std::optional<double> scale,
-    bool compute_log_sumexp);
+    std::optional<double> scale);
 
 std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
 _scaled_dot_product_fused_attention_overrideable_backward(
     const at::Tensor& grad_out,


@@ -4387,7 +4387,7 @@ class TestSDPAXpuOnly(NNTestCase):
         self.assertEqual(actual.contiguous(), math_ref.contiguous(), atol=2e-3, rtol=1e-2)
 
-    @parametrize("fused_kernel", [SDPBackend.OVERRIDEABLE])
+    @parametrize("fused_kernel", [SDPBackend.MATH, SDPBackend.OVERRIDEABLE])
     @parametrize("dtype", [torch.half, torch.bfloat16, torch.float32])
     @parametrize("batch_size,n_head,q_size,kv_size,head_dim", [
         (2, 5, 9216, 9216, 64),
@@ -4426,7 +4426,7 @@ class TestSDPAXpuOnly(NNTestCase):
         tol = Tolerances(5e-2, 5e-2)
         if dtype is torch.float16:
             tol = Tolerances(1e-2, 1e-2)
-        mask_shape = [batch_size, 1, q_size, kv_size]
+        mask_shape = [batch_size, 1, 1, kv_size]
         make_tensor = partial(rand_sdpa_tensor, type="dense", device=device, dtype=dtype, requires_grad=False)
         q_shape = SdpaShape(batch_size, n_head, q_size, head_dim)
         kv_shape = SdpaShape(batch_size, n_head, kv_size, head_dim)
@@ -4435,6 +4435,14 @@ class TestSDPAXpuOnly(NNTestCase):
         v = make_tensor(kv_shape)
         q2, k2, v2 = q.clone(), k.clone(), v.clone()
 
+        if train:
+            q.requires_grad_(True)
+            k.requires_grad_(True)
+            v.requires_grad_(True)
+            q2.requires_grad_(True)
+            k2.requires_grad_(True)
+            v2.requires_grad_(True)
+
         # (B, nh, T, hs)
         q = q.view(batch_size, q_size, n_head, head_dim).transpose(1, 2)
         k = k.view(batch_size, kv_size, n_head, head_dim).transpose(1, 2)
@@ -4454,43 +4462,17 @@ class TestSDPAXpuOnly(NNTestCase):
         v2 = v2.view(batch_size, kv_size, n_head, head_dim).transpose(1, 2)
         attn_mask2 = attn_mask.float() if attn_mask is not None else None
 
-        if train:
-            q = q.detach().clone().requires_grad_(True)
-            k = k.detach().clone().requires_grad_(True)
-            v = v.detach().clone().requires_grad_(True)
-            q2 = q2.detach().clone().requires_grad_(True)
-            k2 = k2.detach().clone().requires_grad_(True)
-            v2 = v2.detach().clone().requires_grad_(True)
-
-        with sdpa_kernel(backends=[fused_kernel]):
-            actual = F.scaled_dot_product_attention(
-                q, k, v, attn_mask=attn_mask, dropout_p=0.0, is_causal=is_causal)
-        with sdpa_kernel(backends=[SDPBackend.MATH]):
-            math_ref = F.scaled_dot_product_attention(
-                q2, k2, v2, attn_mask=attn_mask2, dropout_p=0.0, is_causal=is_causal)
-
-        if dtype in [torch.float16, torch.bfloat16]:
-            math_ref = math_ref.to(dtype)
-
-        self.assertEqual(actual, math_ref, atol=tol.atol, rtol=tol.rtol)
-
-        if train:
-            loss = torch.mean(actual)
-            loss_ref = torch.mean(math_ref)
-            loss.backward()
-            loss_ref.backward()
-
-            grad_q_actual, grad_k_actual, grad_v_actual = q.grad, k.grad, v.grad
-            grad_q_ref, grad_k_ref, grad_v_ref = q2.grad, k2.grad, v2.grad
-
-            if dtype in [torch.float16, torch.bfloat16]:
-                grad_q_ref = grad_q_ref.to(dtype)
-                grad_k_ref = grad_k_ref.to(dtype)
-                grad_v_ref = grad_v_ref.to(dtype)
-
-            self.assertEqual(grad_q_actual, grad_q_ref, atol=tol.atol, rtol=tol.rtol)
-            self.assertEqual(grad_k_actual, grad_k_ref, atol=tol.atol, rtol=tol.rtol)
-            self.assertEqual(grad_v_actual, grad_v_ref, atol=tol.atol, rtol=tol.rtol)
+        if fused_kernel == SDPBackend.MATH:
+            actual = torch.ops.aten._scaled_dot_product_attention_math(
+                q, k, v, attn_mask=attn_mask, dropout_p=0.0, is_causal=is_causal)[0]
+        elif fused_kernel == SDPBackend.OVERRIDEABLE:
+            actual = torch.ops.aten._scaled_dot_product_fused_attention_overrideable(
+                q, k, v, attn_bias=attn_mask, dropout_p=0.0, is_causal=is_causal)[0]
+
+        math_ref = torch.ops.aten._scaled_dot_product_attention_math(
+            q2, k2, v2, attn_mask=attn_mask2, dropout_p=0.0, is_causal=is_causal)[0]
+
+        self.assertEqual(actual.float(), math_ref, atol=tol.atol, rtol=tol.rtol)
 
 
 class TestAttnBias(NNTestCase):


@@ -2907,7 +2907,7 @@
   output_differentiability: [True, False, False, False, False, False, False, False, False]
   query, key, value: _scaled_dot_product_cudnn_attention_backward_symint(grad, query, key, value, output, logsumexp, philox_seed, philox_offset, attn_bias, cum_seq_q, cum_seq_k, max_q, max_k, dropout_p, is_causal, scale)
 
-- name: _scaled_dot_product_fused_attention_overrideable(Tensor query, Tensor key, Tensor value, Tensor? attn_bias=None, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None, bool compute_log_sumexp=True) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
+- name: _scaled_dot_product_fused_attention_overrideable(Tensor query, Tensor key, Tensor value, Tensor? attn_bias=None, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
   output_differentiability: [True, False, False, False, False, False, False, False, False]
   query, key, value, attn_bias: _scaled_dot_product_fused_attention_overrideable_backward_symint(grad, query, key, value, attn_bias, grad_input_mask, output, logsumexp, cum_seq_q, cum_seq_k, max_q, max_k, dropout_p, is_causal, philox_seed, philox_offset, scale)


@@ -5754,7 +5754,6 @@ def meta__scaled_dot_product_fused_attention_overrideable(
     is_causal: bool = False,
     return_debug_mask: bool = False,
     scale: Optional[float] = None,
-    compute_log_sumexp: bool = True,
 ):
     B = query.size(0)
     H_Q = query.size(1)
@@ -5788,36 +5787,6 @@ def meta__scaled_dot_product_fused_attention_overrideable(
     )
 
 
-@register_meta([aten._scaled_dot_product_fused_attention_overrideable_backward])
-def meta__scaled_dot_product_fused_attention_overrideable_backward(
-    grad_out: Tensor,
-    query: Tensor,
-    key: Tensor,
-    value: Tensor,
-    attn_bias: Tensor,
-    grad_input_mask: list[bool],
-    out: Tensor,
-    logsumexp: Tensor,
-    cum_seq_q: Tensor,
-    cum_seq_k: Tensor,
-    max_q: int,
-    max_k: int,
-    dropout_p: float,
-    is_causal: bool,
-    philox_seed: Tensor,
-    philox_offset: Tensor,
-    scale: Optional[float] = None,
-):
-    grad_q = torch.empty_like(query)
-    grad_k = torch.empty_like(key)
-    grad_v = torch.empty_like(value)
-    grad_attn_bias = None
-    if attn_bias is not None:
-        grad_attn_bias = torch.empty_like(attn_bias)
-
-    return grad_q, grad_k, grad_v, grad_attn_bias
-
-
 @register_meta(
     [
         aten._scaled_dot_product_flash_attention_backward,


@@ -36,7 +36,7 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__pdist_backward(AtenTensorHandle
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__pdist_forward(AtenTensorHandle self, double p, AtenTensorHandle* ret0);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__scaled_dot_product_flash_attention_for_cpu(AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, double dropout_p, int32_t is_causal, AtenTensorHandle* attn_mask, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__scaled_dot_product_flash_attention_for_cpu_backward(AtenTensorHandle grad_out, AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle out, AtenTensorHandle logsumexp, double dropout_p, int32_t is_causal, AtenTensorHandle* attn_mask, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2);
-AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__scaled_dot_product_fused_attention_overrideable(AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle* attn_bias, double dropout_p, int32_t is_causal, int32_t return_debug_mask, double* scale, int32_t compute_log_sumexp, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3, int64_t* ret4, int64_t* ret5, AtenTensorHandle* ret6, AtenTensorHandle* ret7, AtenTensorHandle* ret8);
+AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__scaled_dot_product_fused_attention_overrideable(AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle* attn_bias, double dropout_p, int32_t is_causal, int32_t return_debug_mask, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3, int64_t* ret4, int64_t* ret5, AtenTensorHandle* ret6, AtenTensorHandle* ret7, AtenTensorHandle* ret8);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__scaled_dot_product_fused_attention_overrideable_backward(AtenTensorHandle grad_out, AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle attn_bias, const int32_t* grad_input_mask, int64_t grad_input_mask_len_, AtenTensorHandle out, AtenTensorHandle logsumexp, AtenTensorHandle cum_seq_q, AtenTensorHandle cum_seq_k, int64_t max_q, int64_t max_k, double dropout_p, int32_t is_causal, AtenTensorHandle philox_seed, AtenTensorHandle philox_offset, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__scaled_mm(AtenTensorHandle self, AtenTensorHandle mat2, AtenTensorHandle scale_a, AtenTensorHandle scale_b, AtenTensorHandle* bias, AtenTensorHandle* scale_result, int32_t* out_dtype, int32_t use_fast_accum, AtenTensorHandle* ret0);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__scaled_mm_out(AtenTensorHandle out, AtenTensorHandle self, AtenTensorHandle mat2, AtenTensorHandle scale_a, AtenTensorHandle scale_b, AtenTensorHandle* bias, AtenTensorHandle* scale_result, int32_t* out_dtype, int32_t use_fast_accum);


@@ -42,7 +42,7 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__scaled_dot_product_efficient_a
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__scaled_dot_product_efficient_attention_backward(AtenTensorHandle grad_out_, AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle attn_bias, AtenTensorHandle out, AtenTensorHandle logsumexp, AtenTensorHandle philox_seed, AtenTensorHandle philox_offset, double dropout_p, const int32_t* grad_input_mask, int64_t grad_input_mask_len_, int32_t is_causal, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__scaled_dot_product_flash_attention(AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, double dropout_p, int32_t is_causal, int32_t return_debug_mask, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3, int64_t* ret4, int64_t* ret5, AtenTensorHandle* ret6, AtenTensorHandle* ret7, AtenTensorHandle* ret8);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__scaled_dot_product_flash_attention_backward(AtenTensorHandle grad_out, AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle out, AtenTensorHandle logsumexp, AtenTensorHandle cum_seq_q, AtenTensorHandle cum_seq_k, int64_t max_q, int64_t max_k, double dropout_p, int32_t is_causal, AtenTensorHandle philox_seed, AtenTensorHandle philox_offset, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2);
-AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__scaled_dot_product_fused_attention_overrideable(AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle* attn_bias, double dropout_p, int32_t is_causal, int32_t return_debug_mask, double* scale, int32_t compute_log_sumexp, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3, int64_t* ret4, int64_t* ret5, AtenTensorHandle* ret6, AtenTensorHandle* ret7, AtenTensorHandle* ret8);
+AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__scaled_dot_product_fused_attention_overrideable(AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle* attn_bias, double dropout_p, int32_t is_causal, int32_t return_debug_mask, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3, int64_t* ret4, int64_t* ret5, AtenTensorHandle* ret6, AtenTensorHandle* ret7, AtenTensorHandle* ret8);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__scaled_dot_product_fused_attention_overrideable_backward(AtenTensorHandle grad_out, AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle attn_bias, const int32_t* grad_input_mask, int64_t grad_input_mask_len_, AtenTensorHandle out, AtenTensorHandle logsumexp, AtenTensorHandle cum_seq_q, AtenTensorHandle cum_seq_k, int64_t max_q, int64_t max_k, double dropout_p, int32_t is_causal, AtenTensorHandle philox_seed, AtenTensorHandle philox_offset, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__scaled_grouped_mm(AtenTensorHandle self, AtenTensorHandle mat2, AtenTensorHandle scale_a, AtenTensorHandle scale_b, AtenTensorHandle* offs, AtenTensorHandle* bias, AtenTensorHandle* scale_result, int32_t* out_dtype, int32_t use_fast_accum, AtenTensorHandle* ret0);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__scaled_mm(AtenTensorHandle self, AtenTensorHandle mat2, AtenTensorHandle scale_a, AtenTensorHandle scale_b, AtenTensorHandle* bias, AtenTensorHandle* scale_result, int32_t* out_dtype, int32_t use_fast_accum, AtenTensorHandle* ret0);


@@ -25,7 +25,7 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__fused_moving_avg_obs_fq_helper_
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__fused_rms_norm(AtenTensorHandle input, const int64_t* normalized_shape, int64_t normalized_shape_len_, AtenTensorHandle* weight, double* eps, AtenTensorHandle* ret0, AtenTensorHandle* ret1);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__histogramdd_from_bin_cts(AtenTensorHandle self, const int64_t* bins, int64_t bins_len_, const double** range, int64_t range_len_, AtenTensorHandle* weight, int32_t density, AtenTensorHandle* ret0);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__scaled_dot_product_attention_math_for_mps(AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle* attn_mask, double dropout_p, int32_t is_causal, AtenTensorHandle* dropout_mask, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1);
-AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__scaled_dot_product_fused_attention_overrideable(AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle* attn_bias, double dropout_p, int32_t is_causal, int32_t return_debug_mask, double* scale, int32_t compute_log_sumexp, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3, int64_t* ret4, int64_t* ret5, AtenTensorHandle* ret6, AtenTensorHandle* ret7, AtenTensorHandle* ret8);
+AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__scaled_dot_product_fused_attention_overrideable(AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle* attn_bias, double dropout_p, int32_t is_causal, int32_t return_debug_mask, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3, int64_t* ret4, int64_t* ret5, AtenTensorHandle* ret6, AtenTensorHandle* ret7, AtenTensorHandle* ret8);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__scaled_dot_product_fused_attention_overrideable_backward(AtenTensorHandle grad_out, AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle attn_bias, const int32_t* grad_input_mask, int64_t grad_input_mask_len_, AtenTensorHandle out, AtenTensorHandle logsumexp, AtenTensorHandle cum_seq_q, AtenTensorHandle cum_seq_k, int64_t max_q, int64_t max_k, double dropout_p, int32_t is_causal, AtenTensorHandle philox_seed, AtenTensorHandle philox_offset, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__to_sparse(AtenTensorHandle self, int32_t* layout, const int64_t** blocksize, int64_t blocksize_len_, int64_t* dense_dim, AtenTensorHandle* ret0);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__trilinear(AtenTensorHandle i1, AtenTensorHandle i2, AtenTensorHandle i3, const int64_t* expand1, int64_t expand1_len_, const int64_t* expand2, int64_t expand2_len_, const int64_t* expand3, int64_t expand3_len_, const int64_t* sumdim, int64_t sumdim_len_, int64_t unroll_dim, AtenTensorHandle* ret0);


@@ -15,7 +15,7 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu__addmm_activation(AtenTensorHand
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu__fused_moving_avg_obs_fq_helper_functional(AtenTensorHandle self, AtenTensorHandle observer_on, AtenTensorHandle fake_quant_on, AtenTensorHandle running_min, AtenTensorHandle running_max, AtenTensorHandle scale, AtenTensorHandle zero_point, double averaging_const, int64_t quant_min, int64_t quant_max, int64_t ch_axis, int32_t per_row_fake_quant, int32_t symmetric_quant, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3, AtenTensorHandle* ret4, AtenTensorHandle* ret5);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu__fused_rms_norm(AtenTensorHandle input, const int64_t* normalized_shape, int64_t normalized_shape_len_, AtenTensorHandle* weight, double* eps, AtenTensorHandle* ret0, AtenTensorHandle* ret1);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu__int_mm_out(AtenTensorHandle out, AtenTensorHandle self, AtenTensorHandle mat2);
-AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu__scaled_dot_product_fused_attention_overrideable(AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle* attn_bias, double dropout_p, int32_t is_causal, int32_t return_debug_mask, double* scale, int32_t compute_log_sumexp, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3, int64_t* ret4, int64_t* ret5, AtenTensorHandle* ret6, AtenTensorHandle* ret7, AtenTensorHandle* ret8);
+AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu__scaled_dot_product_fused_attention_overrideable(AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle* attn_bias, double dropout_p, int32_t is_causal, int32_t return_debug_mask, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3, int64_t* ret4, int64_t* ret5, AtenTensorHandle* ret6, AtenTensorHandle* ret7, AtenTensorHandle* ret8);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu__scaled_dot_product_fused_attention_overrideable_backward(AtenTensorHandle grad_out, AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle attn_bias, const int32_t* grad_input_mask, int64_t grad_input_mask_len_, AtenTensorHandle out, AtenTensorHandle logsumexp, AtenTensorHandle cum_seq_q, AtenTensorHandle cum_seq_k, int64_t max_q, int64_t max_k, double dropout_p, int32_t is_causal, AtenTensorHandle philox_seed, AtenTensorHandle philox_offset, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu__trilinear(AtenTensorHandle i1, AtenTensorHandle i2, AtenTensorHandle i3, const int64_t* expand1, int64_t expand1_len_, const int64_t* expand2, int64_t expand2_len_, const int64_t* expand3, int64_t expand3_len_, const int64_t* sumdim, int64_t sumdim_len_, int64_t unroll_dim, AtenTensorHandle* ret0);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu__weight_int4pack_mm_with_scales_and_zeros(AtenTensorHandle self, AtenTensorHandle mat2, int64_t qGroupSize, AtenTensorHandle qScale, AtenTensorHandle qZeros, AtenTensorHandle* ret0);