Align meta deducing for fft_r2c with fft_r2c_mkl on XPU (#156048)

There is a memory layout mismatch between the XPU `fft_r2c` implementation and Inductor's meta deduction for it.
Originally, the Inductor meta deduction for `fft_r2c` on the XPU backend was aligned with the CPU (fallback) implementation. This PR corrects the Inductor meta deduction and updates the torch-xpu-ops commit to [intel/torch-xpu-ops@`3a9419c`](3a9419c8bb).
The XPU implementation first performs the R2C transform on the last dimension, followed by iterative C2C transforms on the remaining dimensions.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/156048
Approved by: https://github.com/guangyey, https://github.com/etaf, https://github.com/jansel
This commit is contained in:
Cui, Yifeng 2025-06-20 01:41:03 +00:00 committed by PyTorch MergeBot
parent 159a39ad34
commit 72c8751b61
2 changed files with 4 additions and 9 deletions

2
third_party/xpu.txt vendored
View File

@ -1 +1 @@
a3a196ccdbcbc399e157b6bcf8f5611e6561b6d6
3a9419c8bb6a98dd3e3cd473c36691fb4abeae40

View File

@ -499,14 +499,15 @@ def meta_fft_r2c(self, dim, normalization, onesided):
if onesided:
out_sizes[last_dim] = last_dim_halfsize
if device_hint(self) == "cuda":
if device_hint(self) == "cuda" or device_hint(self) == "xpu":
# _fft_r2c_cufft in aten/src/ATen/native/cuda/SpectralOps.cpp
# _fft_r2c_xpu in torch-xpu-ops/src/ATen/native/xpu/SpectralOps.cpp
output = self.new_empty(
out_sizes, dtype=utils.corresponding_complex_dtype(self.dtype)
)
working_tensor = self
if use_optimized_cufft_path(dim):
if device_hint(self) == "cuda" and use_optimized_cufft_path(dim):
_exec_fft(output, working_tensor, out_sizes, dim, forward=True)
else:
# First do the R2C transform on the last dimension
@ -539,12 +540,6 @@ def meta_fft_r2c(self, dim, normalization, onesided):
return output
elif device_hint(self) == "xpu":
sorted_dims = _sort_dims(self, dim, exclude_last=True)
out = self.new_empty(
out_sizes, dtype=utils.corresponding_complex_dtype(self.dtype)
)
return _exec_fft(out, self, out_sizes, sorted_dims, forward=True)
else:
return self.new_empty(
out_sizes, dtype=utils.corresponding_complex_dtype(self.dtype)