[MPS] Better error checking for FFT ops (#166272)

Namely, error out rather than crash when out dtype is of an unexpected type
Resize output tensor to the expected size in `_out` operation, to prevent crash when tensor of an unexpected size is passed.
Preserve symbolic shapes whenever possible

Test plan: Run `python test_ops.py -v -k test_out_warning_fft_hfft_mps` for MPS device, without this change it crashes with `Error: Invalid KernelDAG, equalShape for destination failed'`, run `python ../test/test_ops.py -v -k test_dtypes_stft_mps`, without this change it crashes with `A complex mlir::Type does not have a corresponding complex MPSDataType"`, when input dtype is bfloat16
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166272
Approved by: https://github.com/kulinseth
This commit is contained in:
Nikita Shulga 2025-10-27 18:07:19 -07:00 committed by PyTorch MergeBot
parent 1425b40f29
commit add37bacda

View File

@ -1,3 +1,5 @@
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/native/Resize.h>
#include <ATen/native/SpectralOpsUtils.h> #include <ATen/native/SpectralOpsUtils.h>
#include <ATen/native/mps/OperationUtils.h> #include <ATen/native/mps/OperationUtils.h>
@ -37,25 +39,12 @@ NSArray<NSNumber*>* IntArrayToNSArray(IntArrayRef arr) {
} // anonymous namespace } // anonymous namespace
Tensor _fft_c2r_mps(const Tensor& self, IntArrayRef dim, int64_t normalization, int64_t last_dim_size) { Tensor _fft_c2r_mps(const Tensor& self, IntArrayRef dim, int64_t normalization, int64_t last_dim_size) {
TORCH_CHECK(self.is_complex()); auto out = at::empty({}, self.options().dtype(c10::toRealValueType(self.scalar_type())));
auto in_sizes = self.sizes();
DimVector out_sizes(in_sizes.begin(), in_sizes.end());
out_sizes[dim.back()] = last_dim_size;
auto out = at::empty(out_sizes, self.options().dtype(c10::toRealValueType(self.scalar_type())));
return _fft_c2r_mps_out(self, dim, normalization, last_dim_size, out); return _fft_c2r_mps_out(self, dim, normalization, last_dim_size, out);
} }
Tensor _fft_r2c_mps(const Tensor& self, IntArrayRef dim, int64_t normalization, bool onesided) { Tensor _fft_r2c_mps(const Tensor& self, IntArrayRef dim, int64_t normalization, bool onesided) {
TORCH_CHECK(self.is_floating_point()); auto out = at::empty({}, self.options().dtype(c10::toComplexType(self.scalar_type())));
auto input_sizes = self.sizes();
DimVector out_sizes(input_sizes.begin(), input_sizes.end());
auto last_dim = dim.back();
auto last_dim_halfsize = (input_sizes[last_dim]) / 2 + 1;
if (onesided) {
out_sizes[last_dim] = last_dim_halfsize;
}
auto out = at::empty(out_sizes, self.options().dtype(c10::toComplexType(self.scalar_type())));
return _fft_r2c_mps_out(self, dim, normalization, onesided, out); return _fft_r2c_mps_out(self, dim, normalization, onesided, out);
} }
@ -72,6 +61,17 @@ using namespace mps;
// TODO: Investigate numerical discrepancies see https://github.com/pytorch/pytorch/issues/120237 // TODO: Investigate numerical discrepancies see https://github.com/pytorch/pytorch/issues/120237
Tensor& _fft_r2c_mps_out(const Tensor& self, IntArrayRef dim, int64_t normalization, bool onesided, Tensor& out) { Tensor& _fft_r2c_mps_out(const Tensor& self, IntArrayRef dim, int64_t normalization, bool onesided, Tensor& out) {
TORCH_CHECK(self.scalar_type() == kFloat || self.scalar_type() == kHalf, "Only float and half dtypes are supported");
TORCH_CHECK(out.scalar_type() == c10::toComplexType(self.scalar_type()));
const auto input_sizes = self.sym_sizes();
SymDimVector out_sizes(input_sizes.begin(), input_sizes.end());
auto last_dim = dim.back();
auto last_dim_halfsize = (input_sizes[last_dim]) / 2 + 1;
if (onesided) {
out_sizes[last_dim] = last_dim_halfsize;
}
at::native::resize_output_symint(out, out_sizes);
auto key = __func__ + getTensorsStringKey({self, out}) + ":" + getArrayRefString(dim) + ":" + auto key = __func__ + getTensorsStringKey({self, out}) + ":" + getArrayRefString(dim) + ":" +
std::to_string(normalization) + ":" + std::to_string(onesided); std::to_string(normalization) + ":" + std::to_string(onesided);
@autoreleasepool { @autoreleasepool {
@ -112,6 +112,12 @@ Tensor& _fft_c2r_mps_out(const Tensor& self,
int64_t normalization, int64_t normalization,
int64_t last_dim_size, int64_t last_dim_size,
Tensor& out) { Tensor& out) {
TORCH_CHECK(self.is_complex(), "Input must be complex");
TORCH_CHECK(out.scalar_type() == c10::toRealValueType(self.scalar_type()), "Unexpected output type");
const auto in_sizes = self.sym_sizes();
SymDimVector out_sizes(in_sizes.begin(), in_sizes.end());
out_sizes[dim.back()] = last_dim_size;
at::native::resize_output_symint(out, out_sizes);
auto key = __func__ + getTensorsStringKey({self}) + ":" + getArrayRefString(dim) + ":" + auto key = __func__ + getTensorsStringKey({self}) + ":" + getArrayRefString(dim) + ":" +
std::to_string(normalization) + ":" + std::to_string(last_dim_size); std::to_string(normalization) + ":" + std::to_string(last_dim_size);
@autoreleasepool { @autoreleasepool {