From a732bbea232fa32191f259d7cb15e9fabb6c2926 Mon Sep 17 00:00:00 2001 From: kshitij12345 Date: Mon, 13 Jun 2022 01:56:42 +0000 Subject: [PATCH] [meta] Add meta support for fft ops (#79311) As per title Pull Request resolved: https://github.com/pytorch/pytorch/pull/79311 Approved by: https://github.com/ezyang --- test/test_meta.py | 28 -------------------------- torch/_meta_registrations.py | 39 ++++++++++++++++++++++++++++++++++++ 2 files changed, 39 insertions(+), 28 deletions(-) diff --git a/test/test_meta.py b/test/test_meta.py index aa6b3f1d06b..c5c1893ab0d 100644 --- a/test/test_meta.py +++ b/test/test_meta.py @@ -378,32 +378,15 @@ meta_function_expected_failures = { torch.bucketize: {bf16, f16, f32, f64, i16, i32, i64, i8, u8}, # aten::bucketize.Tensor, aten::bucketize.Tensor_out torch.combinations: {b8, bf16, f16, f32, f64, i16, i32, i64, i8, u8}, # aten::masked_select torch.complex: {f16, f32, f64}, # aten::complex.out - torch.conj_physical: {c32}, # aten::conj_physical.out torch.corrcoef: {bf16, f32, f64, i16, i32, i64, i8, u8}, # aten::_local_scalar_dense torch.count_nonzero: {b8, bf16, f16, f32, f64, i16, i32, i64, i8, u8}, # aten::count_nonzero.dim_IntList torch.cov: {bf16, f32, f64, i16, i32, i64, i8, u8}, # aten::_local_scalar_dense - torch.fft.fft2: {b8, f32, f64, i16, i32, i64, i8, u8}, # aten::_fft_c2c - torch.fft.fft: {b8, f32, f64, i16, i32, i64, i8, u8}, # aten::_fft_r2c - torch.fft.fftn: {b8, f32, f64, i16, i32, i64, i8, u8}, # aten::_fft_c2c torch.fft.hfft2: {b8, f32, f64, i16, i32, i64, i8, u8}, # aten::_fft_c2c torch.fft.hfft: {b8, f32, f64, i16, i32, i64, i8, u8}, torch.fft.hfftn: {b8, f32, f64, i16, i32, i64, i8, u8}, # aten::_fft_c2c - torch.fft.ifft2: {b8, f32, f64, i16, i32, i64, i8, u8}, # aten::_fft_c2c - torch.fft.ifft: {b8, f32, f64, i16, i32, i64, i8, u8}, # aten::_fft_r2c - torch.fft.ifftn: {b8, f32, f64, i16, i32, i64, i8, u8}, # aten::_fft_c2c - torch.fft.ihfft2: {b8, f32, f64, i16, i32, i64, i8, u8}, # aten::_fft_r2c - torch.fft.ihfft: {b8, f32, f64, i16, i32, i64, i8, u8}, # aten::_fft_r2c - torch.fft.ihfftn: {b8, f32, f64, i16, i32, i64, i8, u8}, # aten::_fft_r2c - torch.fft.irfft2: {b8, f32, f64, i16, i32, i64, i8, u8}, # aten::_fft_c2r, aten::_fft_c2r.out - torch.fft.irfft: {b8, f32, f64, i16, i32, i64, i8, u8}, # aten::_fft_c2r, aten::_fft_c2r.out - torch.fft.irfftn: {b8, f32, f64, i16, i32, i64, i8, u8}, # aten::_fft_c2r, aten::_fft_c2r.out - torch.fft.rfft2: {b8, f32, f64, i16, i32, i64, i8, u8}, # aten::_fft_r2c - torch.fft.rfft: {b8, f32, f64, i16, i32, i64, i8, u8}, # aten::_fft_r2c - torch.fft.rfftn: {b8, f32, f64, i16, i32, i64, i8, u8}, # aten::_fft_r2c torch.floor_divide: {bf16, f16, f32, f64, i16, i32, i64, i8, u8}, # aten::floor_divide, aten::floor_divide.out torch.frexp: {bf16, f16, f32, f64}, # aten::frexp.Tensor_out torch.functional.istft: {f32, f64}, # aten::view_as_complex - torch.functional.stft: {f32, f64}, # aten::_fft_r2c torch.functional.unique: {b8, bf16, f32, f64, i16, i32, i64, i8, u8}, # aten::_unique2, aten::unique_dim torch.functional.unique_consecutive: {b8, bf16, f32, f64, i16, i32, i64, i8, u8}, # aten::unique_consecutive torch.histc: {bf16, f32, f64}, # aten::histc, aten::histc.out @@ -480,7 +463,6 @@ sys.exit() meta_function_skips = { torch.aminmax: {b8, f32, f64, i16, i32, i64, i8, u8}, - torch.conj_physical: {b8, bf16, f16, f32, f64, i16, i32, i64, i8, u8}, torch.cummax: {b8, bf16, f32, f64, i16, i32, i64, i8, u8}, torch.cummin: {b8, bf16, f32, f64, i16, i32, i64, i8, u8}, torch.diff: {b8}, @@ -618,10 +600,8 @@ aten = torch.ops.aten # these always fail meta_dispatch_expected_failures = { - aten._conj_physical.default: {c32}, aten._convolution.default: {c64, i64, f64, c128, bf16, f32}, aten._ctc_loss.default: {f64, f32}, - aten._fft_r2c.default: {i64, u8, b8, f32, i8, f64, i16, i32}, aten._histogramdd_bin_edges.default: {f64, f32}, aten._histogramdd_from_bin_cts.default: {f64, f32}, aten._histogramdd_from_bin_tensors.default: {f64, f32}, @@ -634,7 +614,6 @@ meta_dispatch_expected_failures = { aten.col2im.default: {c64, f32, f64, c128}, aten.complex.default: {c64, f64, c128, f16, f32}, aten.complex.out: {f16}, - aten.conj_physical.out: {i64, bf16, f16, u8, b8, f32, i8, f64, i16, c32, i32}, aten.convolution.default: {c64, i64, f64, c128, bf16, f32}, aten.count_nonzero.default: {i64, bf16, f16, u8, b8, f32, i8, f64, i16, i32}, aten.count_nonzero.dim_IntList: {i64, bf16, f16, u8, b8, f32, i8, f64, i16, i32}, @@ -732,14 +711,7 @@ meta_dispatch_device_expected_failures = defaultdict(dict) meta_dispatch_device_skips = defaultdict(dict) meta_dispatch_device_expected_failures['cuda'] = { - aten._conj_physical.default: {f16}, # aten::conj_physical.out aten._convolution.default: {f16, c32}, - aten._fft_c2c.default: {c32, f16}, # aten::_fft_c2c - aten._fft_c2c.out: {c32, f16}, # aten::_fft_c2c.out - aten._fft_c2r.default: {c32, f16}, # aten::_fft_c2r - aten._fft_c2r.out: {c32, f16}, # aten::_fft_c2r.out - aten._fft_r2c.default: {f16}, # aten::_fft_r2c - aten._fft_r2c.out: {f16}, # aten::_fft_r2c.out aten._unique2.default: {f16}, # aten::_unique2 aten._use_cudnn_ctc_loss.default: {f32, f64}, # aten::_use_cudnn_ctc_loss aten.convolution.default: {f16, c32}, diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py index 88f3ad043d1..828403aca54 100644 --- a/torch/_meta_registrations.py +++ b/torch/_meta_registrations.py @@ -22,6 +22,45 @@ def toRealValueType(dtype): return from_complex.get(dtype, dtype) +@torch.library.impl(meta_lib, "_fft_c2c") +def meta_fft_c2c(self, dim, normalization, forward): + assert self.dtype.is_complex + return self.new_empty(self.size()) + + +@torch.library.impl(meta_lib, "_fft_r2c") +def meta_fft_r2c(self, dim, normalization, onesided): + assert self.dtype.is_floating_point + output_sizes = list(self.size()) + + if onesided: + last_dim = dim[-1] + last_dim_halfsize = (output_sizes[last_dim] // 2) + 1 + output_sizes[last_dim] = last_dim_halfsize + + return self.new_empty( + output_sizes, dtype=utils.corresponding_complex_dtype(self.dtype) + ) + + +@out_wrapper +def meta_fft_c2r(self, dim, normalization, lastdim): + assert self.dtype.is_complex + output_sizes = list(self.size()) + output_sizes[dim[-1]] = lastdim + return self.new_empty(output_sizes, dtype=toRealValueType(self.dtype)) + + +torch.library.impl(meta_lib, "_fft_c2r")(meta_fft_c2r) +torch.library.impl(meta_lib, "_fft_c2r.out")(meta_fft_c2r) + + +@torch.library.impl(meta_lib, "conj_physical.out") +def meta_conj_physical_out(self, out): + torch._resize_output_(out, self.size(), self.device) + return out.copy_(self) + + # Implementations below are taken from https://github.com/albanD/subclass_zoo/blob/main/python_meta_tensor.py @torch.library.impl(meta_lib, "index_select") def meta_index_select(self, dim, index):