diff --git a/aten/src/ATen/native/mps/operations/FastFourierTransform.mm b/aten/src/ATen/native/mps/operations/FastFourierTransform.mm
index 9a208e814cf..9f50a2343a4 100644
--- a/aten/src/ATen/native/mps/operations/FastFourierTransform.mm
+++ b/aten/src/ATen/native/mps/operations/FastFourierTransform.mm
@@ -1,3 +1,5 @@
+#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
+#include <ATen/native/Resize.h>
 #include <ATen/native/SpectralOpsUtils.h>
 #include <ATen/native/mps/OperationUtils.h>
 
@@ -37,25 +39,12 @@ NSArray* IntArrayToNSArray(IntArrayRef arr) {
 } // anonymous namespace
 
 Tensor _fft_c2r_mps(const Tensor& self, IntArrayRef dim, int64_t normalization, int64_t last_dim_size) {
-  TORCH_CHECK(self.is_complex());
-  auto in_sizes = self.sizes();
-  DimVector out_sizes(in_sizes.begin(), in_sizes.end());
-  out_sizes[dim.back()] = last_dim_size;
-  auto out = at::empty(out_sizes, self.options().dtype(c10::toRealValueType(self.scalar_type())));
+  auto out = at::empty({}, self.options().dtype(c10::toRealValueType(self.scalar_type())));
   return _fft_c2r_mps_out(self, dim, normalization, last_dim_size, out);
 }
 
 Tensor _fft_r2c_mps(const Tensor& self, IntArrayRef dim, int64_t normalization, bool onesided) {
-  TORCH_CHECK(self.is_floating_point());
-  auto input_sizes = self.sizes();
-  DimVector out_sizes(input_sizes.begin(), input_sizes.end());
-  auto last_dim = dim.back();
-  auto last_dim_halfsize = (input_sizes[last_dim]) / 2 + 1;
-  if (onesided) {
-    out_sizes[last_dim] = last_dim_halfsize;
-  }
-
-  auto out = at::empty(out_sizes, self.options().dtype(c10::toComplexType(self.scalar_type())));
+  auto out = at::empty({}, self.options().dtype(c10::toComplexType(self.scalar_type())));
   return _fft_r2c_mps_out(self, dim, normalization, onesided, out);
 }
 
@@ -72,6 +61,17 @@ using namespace mps;
 
 // TODO: Investigate numerical discrepancies see https://github.com/pytorch/pytorch/issues/120237
 Tensor& _fft_r2c_mps_out(const Tensor& self, IntArrayRef dim, int64_t normalization, bool onesided, Tensor& out) {
+  TORCH_CHECK(self.scalar_type() == kFloat || self.scalar_type() == kHalf, "Only float and half dtypes are supported");
+  TORCH_CHECK(out.scalar_type() == c10::toComplexType(self.scalar_type()));
+  const auto input_sizes = self.sym_sizes();
+  SymDimVector out_sizes(input_sizes.begin(), input_sizes.end());
+  auto last_dim = dim.back();
+  auto last_dim_halfsize = (input_sizes[last_dim]) / 2 + 1;
+  if (onesided) {
+    out_sizes[last_dim] = last_dim_halfsize;
+  }
+  at::native::resize_output_symint(out, out_sizes);
+
   auto key = __func__ + getTensorsStringKey({self, out}) + ":" + getArrayRefString(dim) + ":" +
       std::to_string(normalization) + ":" + std::to_string(onesided);
   @autoreleasepool {
@@ -112,6 +112,12 @@ Tensor& _fft_c2r_mps_out(const Tensor& self,
                          int64_t normalization,
                          int64_t last_dim_size,
                          Tensor& out) {
+  TORCH_CHECK(self.is_complex(), "Input must be complex");
+  TORCH_CHECK(out.scalar_type() == c10::toRealValueType(self.scalar_type()), "Unexpected output type");
+  const auto in_sizes = self.sym_sizes();
+  SymDimVector out_sizes(in_sizes.begin(), in_sizes.end());
+  out_sizes[dim.back()] = last_dim_size;
+  at::native::resize_output_symint(out, out_sizes);
   auto key = __func__ + getTensorsStringKey({self}) + ":" + getArrayRefString(dim) + ":" +
       std::to_string(normalization) + ":" + std::to_string(last_dim_size);
   @autoreleasepool {
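
As a quick sanity check of the new behavior: after this patch the `out=` kernels compute the output shape and resize `out` themselves (via `at::native::resize_output_symint`), so the functional wrappers only need to hand over an empty placeholder tensor. A minimal Python-level sketch, assuming an MPS-enabled PyTorch build (it falls back to CPU otherwise):

```python
import torch

device = "mps" if torch.backends.mps.is_available() else "cpu"

# r2c path: real input -> onesided complex output of width n // 2 + 1.
x = torch.randn(8, 16, device=device)
out_c = torch.empty(0, dtype=torch.complex64, device=device)
torch.fft.rfft(x, out=out_c)   # the out kernel resizes out_c to (8, 9)
print(out_c.shape)             # torch.Size([8, 9])

# c2r path: complex input -> real output of width last_dim_size.
out_r = torch.empty(0, dtype=torch.float32, device=device)
torch.fft.irfft(out_c, n=16, out=out_r)
print(out_r.shape)             # torch.Size([8, 16])
```

Moving the shape logic into the out variants also switches it from `sizes()`/`DimVector` to `sym_sizes()`/`SymDimVector`, so the resize stays well-defined when the sizes are symbolic.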