mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[MPS] Better error checking for FFT ops (#166272)
Namely, error out rather than crash when out dtype is of an unexpected type Resize output tensor to the expected size in `_out` operation, to prevent crash when tensor of an unexpected size is passed. Preserve symbolic shapes whenever possible Test plan: Run `python test_ops.py -v -k test_out_warning_fft_hfft_mps` for MPS device, without this change it crashes with `Error: Invalid KernelDAG, equalShape for destination failed'`, run `python ../test/test_ops.py -v -k test_dtypes_stft_mps`, without this change it crashes with `A complex mlir::Type does not have a corresponding complex MPSDataType"`, when input dtype is bfloat16 Pull Request resolved: https://github.com/pytorch/pytorch/pull/166272 Approved by: https://github.com/kulinseth
This commit is contained in:
parent
1425b40f29
commit
add37bacda
|
|
@ -1,3 +1,5 @@
|
|||
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
|
||||
#include <ATen/native/Resize.h>
|
||||
#include <ATen/native/SpectralOpsUtils.h>
|
||||
#include <ATen/native/mps/OperationUtils.h>
|
||||
|
||||
|
|
@ -37,25 +39,12 @@ NSArray<NSNumber*>* IntArrayToNSArray(IntArrayRef arr) {
|
|||
} // anonymous namespace
|
||||
|
||||
Tensor _fft_c2r_mps(const Tensor& self, IntArrayRef dim, int64_t normalization, int64_t last_dim_size) {
|
||||
TORCH_CHECK(self.is_complex());
|
||||
auto in_sizes = self.sizes();
|
||||
DimVector out_sizes(in_sizes.begin(), in_sizes.end());
|
||||
out_sizes[dim.back()] = last_dim_size;
|
||||
auto out = at::empty(out_sizes, self.options().dtype(c10::toRealValueType(self.scalar_type())));
|
||||
auto out = at::empty({}, self.options().dtype(c10::toRealValueType(self.scalar_type())));
|
||||
return _fft_c2r_mps_out(self, dim, normalization, last_dim_size, out);
|
||||
}
|
||||
|
||||
Tensor _fft_r2c_mps(const Tensor& self, IntArrayRef dim, int64_t normalization, bool onesided) {
|
||||
TORCH_CHECK(self.is_floating_point());
|
||||
auto input_sizes = self.sizes();
|
||||
DimVector out_sizes(input_sizes.begin(), input_sizes.end());
|
||||
auto last_dim = dim.back();
|
||||
auto last_dim_halfsize = (input_sizes[last_dim]) / 2 + 1;
|
||||
if (onesided) {
|
||||
out_sizes[last_dim] = last_dim_halfsize;
|
||||
}
|
||||
|
||||
auto out = at::empty(out_sizes, self.options().dtype(c10::toComplexType(self.scalar_type())));
|
||||
auto out = at::empty({}, self.options().dtype(c10::toComplexType(self.scalar_type())));
|
||||
return _fft_r2c_mps_out(self, dim, normalization, onesided, out);
|
||||
}
|
||||
|
||||
|
|
@ -72,6 +61,17 @@ using namespace mps;
|
|||
|
||||
// TODO: Investigate numerical discrepancies see https://github.com/pytorch/pytorch/issues/120237
|
||||
Tensor& _fft_r2c_mps_out(const Tensor& self, IntArrayRef dim, int64_t normalization, bool onesided, Tensor& out) {
|
||||
TORCH_CHECK(self.scalar_type() == kFloat || self.scalar_type() == kHalf, "Only float and half dtypes are supported");
|
||||
TORCH_CHECK(out.scalar_type() == c10::toComplexType(self.scalar_type()));
|
||||
const auto input_sizes = self.sym_sizes();
|
||||
SymDimVector out_sizes(input_sizes.begin(), input_sizes.end());
|
||||
auto last_dim = dim.back();
|
||||
auto last_dim_halfsize = (input_sizes[last_dim]) / 2 + 1;
|
||||
if (onesided) {
|
||||
out_sizes[last_dim] = last_dim_halfsize;
|
||||
}
|
||||
at::native::resize_output_symint(out, out_sizes);
|
||||
|
||||
auto key = __func__ + getTensorsStringKey({self, out}) + ":" + getArrayRefString(dim) + ":" +
|
||||
std::to_string(normalization) + ":" + std::to_string(onesided);
|
||||
@autoreleasepool {
|
||||
|
|
@ -112,6 +112,12 @@ Tensor& _fft_c2r_mps_out(const Tensor& self,
|
|||
int64_t normalization,
|
||||
int64_t last_dim_size,
|
||||
Tensor& out) {
|
||||
TORCH_CHECK(self.is_complex(), "Input must be complex");
|
||||
TORCH_CHECK(out.scalar_type() == c10::toRealValueType(self.scalar_type()), "Unexpected output type");
|
||||
const auto in_sizes = self.sym_sizes();
|
||||
SymDimVector out_sizes(in_sizes.begin(), in_sizes.end());
|
||||
out_sizes[dim.back()] = last_dim_size;
|
||||
at::native::resize_output_symint(out, out_sizes);
|
||||
auto key = __func__ + getTensorsStringKey({self}) + ":" + getArrayRefString(dim) + ":" +
|
||||
std::to_string(normalization) + ":" + std::to_string(last_dim_size);
|
||||
@autoreleasepool {
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user