[meta] Add meta support for fft ops (#79311)

As per title
Pull Request resolved: https://github.com/pytorch/pytorch/pull/79311
Approved by: https://github.com/ezyang
kshitij12345 2022-06-13 01:56:42 +00:00 committed by PyTorch MergeBot
parent bd1a35dfc8
commit a732bbea23
2 changed files with 39 additions and 28 deletions


@@ -378,32 +378,15 @@ meta_function_expected_failures = {
torch.bucketize: {bf16, f16, f32, f64, i16, i32, i64, i8, u8}, # aten::bucketize.Tensor, aten::bucketize.Tensor_out
torch.combinations: {b8, bf16, f16, f32, f64, i16, i32, i64, i8, u8}, # aten::masked_select
torch.complex: {f16, f32, f64}, # aten::complex.out
torch.conj_physical: {c32}, # aten::conj_physical.out
torch.corrcoef: {bf16, f32, f64, i16, i32, i64, i8, u8}, # aten::_local_scalar_dense
torch.count_nonzero: {b8, bf16, f16, f32, f64, i16, i32, i64, i8, u8}, # aten::count_nonzero.dim_IntList
torch.cov: {bf16, f32, f64, i16, i32, i64, i8, u8}, # aten::_local_scalar_dense
torch.fft.fft2: {b8, f32, f64, i16, i32, i64, i8, u8}, # aten::_fft_c2c
torch.fft.fft: {b8, f32, f64, i16, i32, i64, i8, u8}, # aten::_fft_r2c
torch.fft.fftn: {b8, f32, f64, i16, i32, i64, i8, u8}, # aten::_fft_c2c
torch.fft.hfft2: {b8, f32, f64, i16, i32, i64, i8, u8}, # aten::_fft_c2c
torch.fft.hfft: {b8, f32, f64, i16, i32, i64, i8, u8},
torch.fft.hfftn: {b8, f32, f64, i16, i32, i64, i8, u8}, # aten::_fft_c2c
torch.fft.ifft2: {b8, f32, f64, i16, i32, i64, i8, u8}, # aten::_fft_c2c
torch.fft.ifft: {b8, f32, f64, i16, i32, i64, i8, u8}, # aten::_fft_r2c
torch.fft.ifftn: {b8, f32, f64, i16, i32, i64, i8, u8}, # aten::_fft_c2c
torch.fft.ihfft2: {b8, f32, f64, i16, i32, i64, i8, u8}, # aten::_fft_r2c
torch.fft.ihfft: {b8, f32, f64, i16, i32, i64, i8, u8}, # aten::_fft_r2c
torch.fft.ihfftn: {b8, f32, f64, i16, i32, i64, i8, u8}, # aten::_fft_r2c
torch.fft.irfft2: {b8, f32, f64, i16, i32, i64, i8, u8}, # aten::_fft_c2r, aten::_fft_c2r.out
torch.fft.irfft: {b8, f32, f64, i16, i32, i64, i8, u8}, # aten::_fft_c2r, aten::_fft_c2r.out
torch.fft.irfftn: {b8, f32, f64, i16, i32, i64, i8, u8}, # aten::_fft_c2r, aten::_fft_c2r.out
torch.fft.rfft2: {b8, f32, f64, i16, i32, i64, i8, u8}, # aten::_fft_r2c
torch.fft.rfft: {b8, f32, f64, i16, i32, i64, i8, u8}, # aten::_fft_r2c
torch.fft.rfftn: {b8, f32, f64, i16, i32, i64, i8, u8}, # aten::_fft_r2c
torch.floor_divide: {bf16, f16, f32, f64, i16, i32, i64, i8, u8}, # aten::floor_divide, aten::floor_divide.out
torch.frexp: {bf16, f16, f32, f64}, # aten::frexp.Tensor_out
torch.functional.istft: {f32, f64}, # aten::view_as_complex
torch.functional.stft: {f32, f64}, # aten::_fft_r2c
torch.functional.unique: {b8, bf16, f32, f64, i16, i32, i64, i8, u8}, # aten::_unique2, aten::unique_dim
torch.functional.unique_consecutive: {b8, bf16, f32, f64, i16, i32, i64, i8, u8}, # aten::unique_consecutive
torch.histc: {bf16, f32, f64}, # aten::histc, aten::histc.out
@@ -480,7 +463,6 @@ sys.exit()
meta_function_skips = {
torch.aminmax: {b8, f32, f64, i16, i32, i64, i8, u8},
torch.conj_physical: {b8, bf16, f16, f32, f64, i16, i32, i64, i8, u8},
torch.cummax: {b8, bf16, f32, f64, i16, i32, i64, i8, u8},
torch.cummin: {b8, bf16, f32, f64, i16, i32, i64, i8, u8},
torch.diff: {b8},
@@ -618,10 +600,8 @@ aten = torch.ops.aten
# these always fail
meta_dispatch_expected_failures = {
aten._conj_physical.default: {c32},
aten._convolution.default: {c64, i64, f64, c128, bf16, f32},
aten._ctc_loss.default: {f64, f32},
aten._fft_r2c.default: {i64, u8, b8, f32, i8, f64, i16, i32},
aten._histogramdd_bin_edges.default: {f64, f32},
aten._histogramdd_from_bin_cts.default: {f64, f32},
aten._histogramdd_from_bin_tensors.default: {f64, f32},
@@ -634,7 +614,6 @@ meta_dispatch_expected_failures = {
aten.col2im.default: {c64, f32, f64, c128},
aten.complex.default: {c64, f64, c128, f16, f32},
aten.complex.out: {f16},
aten.conj_physical.out: {i64, bf16, f16, u8, b8, f32, i8, f64, i16, c32, i32},
aten.convolution.default: {c64, i64, f64, c128, bf16, f32},
aten.count_nonzero.default: {i64, bf16, f16, u8, b8, f32, i8, f64, i16, i32},
aten.count_nonzero.dim_IntList: {i64, bf16, f16, u8, b8, f32, i8, f64, i16, i32},
@@ -732,14 +711,7 @@ meta_dispatch_device_expected_failures = defaultdict(dict)
meta_dispatch_device_skips = defaultdict(dict)
meta_dispatch_device_expected_failures['cuda'] = {
aten._conj_physical.default: {f16}, # aten::conj_physical.out
aten._convolution.default: {f16, c32},
aten._fft_c2c.default: {c32, f16}, # aten::_fft_c2c
aten._fft_c2c.out: {c32, f16}, # aten::_fft_c2c.out
aten._fft_c2r.default: {c32, f16}, # aten::_fft_c2r
aten._fft_c2r.out: {c32, f16}, # aten::_fft_c2r.out
aten._fft_r2c.default: {f16}, # aten::_fft_r2c
aten._fft_r2c.out: {f16}, # aten::_fft_r2c.out
aten._unique2.default: {f16}, # aten::_unique2
aten._use_cudnn_ctc_loss.default: {f32, f64}, # aten::_use_cudnn_ctc_loss
aten.convolution.default: {f16, c32},

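The FFT-related entries shown above are dropped from the expected-failure and skip lists in test_meta.py because the meta kernels added below now handle them. A rough illustration of what the test asserts (a sketch, not the actual harness; it assumes a PyTorch build that includes these registrations):

import torch

# FFT ops now run under the "meta" device and report the same output
# shape/dtype as a real run, without doing any computation.
real = torch.fft.fft2(torch.randn(4, 6))
meta = torch.fft.fft2(torch.randn(4, 6, device="meta"))
assert meta.shape == real.shape and meta.dtype == real.dtype  # (4, 6), complex64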

@@ -22,6 +22,45 @@ def toRealValueType(dtype):
    return from_complex.get(dtype, dtype)

@torch.library.impl(meta_lib, "_fft_c2c")
def meta_fft_c2c(self, dim, normalization, forward):
    assert self.dtype.is_complex
    return self.new_empty(self.size())

@torch.library.impl(meta_lib, "_fft_r2c")
def meta_fft_r2c(self, dim, normalization, onesided):
    assert self.dtype.is_floating_point
    output_sizes = list(self.size())

    if onesided:
        last_dim = dim[-1]
        last_dim_halfsize = (output_sizes[last_dim] // 2) + 1
        output_sizes[last_dim] = last_dim_halfsize

    return self.new_empty(
        output_sizes, dtype=utils.corresponding_complex_dtype(self.dtype)
    )

@out_wrapper
def meta_fft_c2r(self, dim, normalization, lastdim):
    assert self.dtype.is_complex
    output_sizes = list(self.size())
    output_sizes[dim[-1]] = lastdim
    return self.new_empty(output_sizes, dtype=toRealValueType(self.dtype))

torch.library.impl(meta_lib, "_fft_c2r")(meta_fft_c2r)
torch.library.impl(meta_lib, "_fft_c2r.out")(meta_fft_c2r)

@torch.library.impl(meta_lib, "conj_physical.out")
def meta_conj_physical_out(self, out):
    torch._resize_output_(out, self.size(), self.device)
    return out.copy_(self)
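For illustration, a minimal shape-inference sketch of the rules these registrations encode (assuming a PyTorch build that includes them): a one-sided real-to-complex transform keeps n // 2 + 1 frequencies in the last transformed dimension, and the complex-to-real inverse restores the requested real signal length.

import torch

# Sketch: assumes the meta registrations above are available.
# rfft routes through _fft_r2c with onesided=True: 10 // 2 + 1 = 6 bins.
spec = torch.fft.rfft(torch.empty(10, device="meta"))
print(spec.shape, spec.dtype)   # torch.Size([6]) torch.complex64

# irfft routes through _fft_c2r with lastdim=10: the real length comes back.
sig = torch.fft.irfft(spec, n=10)
print(sig.shape, sig.dtype)     # torch.Size([10]) torch.float32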
# Implementations below are taken from https://github.com/albanD/subclass_zoo/blob/main/python_meta_tensor.py
@torch.library.impl(meta_lib, "index_select")
def meta_index_select(self, dim, index):