mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Align meta deducing for fft_r2c with fft_r2c_mkl on XPU (#156048)
There is a memory layout mismatch between the XPU implementation of `fft_r2c` and the Inductor meta deduction for it.
Originally, the Inductor meta deduction of `fft_r2c` for the XPU backend was aligned with the CPU (fallback) implementation. This PR corrects the Inductor meta deduction and updates the torch-xpu-ops commit to [intel/torch-xpu-ops@`3a9419c`](3a9419c8bb).
The XPU implementation first performs the R2C transform on the last dimension, followed by iterative C2C transforms on the remaining dimensions.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/156048
Approved by: https://github.com/guangyey, https://github.com/etaf, https://github.com/jansel
This commit is contained in:
parent
159a39ad34
commit
72c8751b61
2
third_party/xpu.txt
vendored
2
third_party/xpu.txt
vendored
|
|
@ -1 +1 @@
|
|||
a3a196ccdbcbc399e157b6bcf8f5611e6561b6d6
|
||||
3a9419c8bb6a98dd3e3cd473c36691fb4abeae40
|
||||
|
|
|
|||
|
|
@ -499,14 +499,15 @@ def meta_fft_r2c(self, dim, normalization, onesided):
|
|||
if onesided:
|
||||
out_sizes[last_dim] = last_dim_halfsize
|
||||
|
||||
if device_hint(self) == "cuda":
|
||||
if device_hint(self) == "cuda" or device_hint(self) == "xpu":
|
||||
# _fft_r2c_cufft in aten/src/ATen/native/cuda/SpectralOps.cpp
|
||||
# _fft_r2c_xpu in torch-xpu-ops/src/ATen/native/xpu/SpectralOps.cpp
|
||||
output = self.new_empty(
|
||||
out_sizes, dtype=utils.corresponding_complex_dtype(self.dtype)
|
||||
)
|
||||
|
||||
working_tensor = self
|
||||
if use_optimized_cufft_path(dim):
|
||||
if device_hint(self) == "cuda" and use_optimized_cufft_path(dim):
|
||||
_exec_fft(output, working_tensor, out_sizes, dim, forward=True)
|
||||
else:
|
||||
# First do the R2C transform on the last dimension
|
||||
|
|
@ -539,12 +540,6 @@ def meta_fft_r2c(self, dim, normalization, onesided):
|
|||
|
||||
return output
|
||||
|
||||
elif device_hint(self) == "xpu":
|
||||
sorted_dims = _sort_dims(self, dim, exclude_last=True)
|
||||
out = self.new_empty(
|
||||
out_sizes, dtype=utils.corresponding_complex_dtype(self.dtype)
|
||||
)
|
||||
return _exec_fft(out, self, out_sizes, sorted_dims, forward=True)
|
||||
else:
|
||||
return self.new_empty(
|
||||
out_sizes, dtype=utils.corresponding_complex_dtype(self.dtype)
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user