Add complex32 dtype support to CPU/GPU implementation of (#45339)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/45339

Test Plan:
Imported from OSS

GPU implementation already works as-is:
$ python -c "import torch; a = torch.tensor([1j], dtype=torch.complex32, device=torch.device('cuda')); b = a.clone(); print(b); print(a)"
tensor([0.+1.j], device='cuda:0', dtype=torch.complex32)
tensor([0.+1.j], device='cuda:0', dtype=torch.complex32)

Test for CPU implementation:
$ python -c "import torch; a = torch.tensor([1j], dtype=torch.complex32); b = a.clone(); print(b); print(a)"
tensor([0.+1.j], dtype=torch.complex32)
tensor([0.+1.j], dtype=torch.complex32)

Reviewed By: malfet

Differential Revision: D23932649

Pulled By: soulitzer

fbshipit-source-id: 394b6e1f3d462ee8a010f56f4bb8404af92a066b
This commit is contained in:
Jeffrey Wan 2020-10-08 09:27:09 -07:00 committed by Facebook GitHub Bot
parent 7d4f5060ad
commit 735d5b8907

View File

@ -17,6 +17,8 @@ static void copy_kernel(TensorIterator& iter, bool non_blocking) {
cpu_kernel(iter, [=](at::Half a) -> at::Half { return a; });
} else if (dtype == ScalarType::BFloat16) {
cpu_kernel(iter, [=](at::BFloat16 a) -> at::BFloat16 { return a; });
} else if (dtype == ScalarType::ComplexHalf) {
cpu_kernel(iter, [=](c10::complex<at::Half> a) -> c10::complex<at::Half> { return a; });
} else if (isQIntType(dtype)) {
AT_DISPATCH_QINT_TYPES(dtype, "copy_kernel", [&] {
cpu_kernel_vec(