mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[Intel GPU] Fix SDPA dummy LSE output to match meta function (#148652)
To fix XPU patched UTs, including `pytest -vs third_party/torch-xpu-ops/test/xpu/test_meta_xpu.py::TestMetaXPU::test_dispatch_symbolic_meta_outplace_nn_functional_scaled_dot_product_attention_xpu_bfloat16`. Pull Request resolved: https://github.com/pytorch/pytorch/pull/148652 Approved by: https://github.com/EikanWang
This commit is contained in:
parent
416ea1c71c
commit
243b47e2ec
|
|
@@ -190,9 +190,7 @@ _scaled_dot_product_fused_attention_overrideable_xpu(
|
|||
|
||||
auto opts = query.options();
|
||||
auto output = at::empty({batch_size, num_head, seq_len_q, head_dim}, opts);
|
||||
// auto logsumexp =
|
||||
// at::empty({batch_size, num_head, seq_len_q}, opts.dtype(at::kFloat));
|
||||
auto logsumexp = at::empty({}, opts.dtype(at::kFloat));
|
||||
at::Tensor logsumexp, debug_attn_mask; // not supported
|
||||
|
||||
at::native::onednn::gpu_float_sdpa(
|
||||
batch_size,
|
||||
|
|
@@ -210,12 +208,9 @@ _scaled_dot_product_fused_attention_overrideable_xpu(
|
|||
scale.has_value() ? scale.value() : (1.0 / std::sqrt(head_dim)),
|
||||
output);
|
||||
|
||||
// rng and debug mask not used
|
||||
// rng not used
|
||||
auto philox_seed = at::empty({}, at::dtype(at::kLong));
|
||||
auto philox_offset = at::empty({}, at::dtype(at::kLong));
|
||||
auto debug_attn_mask = at::empty(
|
||||
{batch_size, num_head, seq_len_q, seq_len_kv}, at::dtype(at::kFloat));
|
||||
|
||||
return std::make_tuple(
|
||||
output,
|
||||
logsumexp,
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user