Add complex32 dtype support to CPU/GPU implementation of copy_ (#45339)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/45339

Test Plan: Imported from OSS

GPU implementation already works as-is:

$ python -c "import torch; a = torch.tensor([1j], dtype=torch.complex32, device=torch.device('cuda')); b = a.clone(); print(b); print(a)"
tensor([0.+1.j], device='cuda:0', dtype=torch.complex32)
tensor([0.+1.j], device='cuda:0', dtype=torch.complex32)

Test for CPU implementation:

$ python -c "import torch; a = torch.tensor([1j], dtype=torch.complex32); b = a.clone(); print(b); print(a)"
tensor([0.+1.j], dtype=torch.complex32)
tensor([0.+1.j], dtype=torch.complex32)

Reviewed By: malfet

Differential Revision: D23932649

Pulled By: soulitzer

fbshipit-source-id: 394b6e1f3d462ee8a010f56f4bb8404af92a066b
parent 7d4f5060ad
commit 735d5b8907
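For readability, the two Test Plan one-liners above expand to roughly the following sketch. The tensor values and dtypes come directly from the Test Plan; the torch.cuda.is_available() guard is an added convenience (not part of the original commands) so the script also runs on CPU-only builds:

import torch

# CPU path: clone() copies via copy_, which exercises the new
# ScalarType::ComplexHalf branch in copy_kernel.
a = torch.tensor([1j], dtype=torch.complex32)
b = a.clone()
print(b)  # tensor([0.+1.j], dtype=torch.complex32)
print(a)  # tensor([0.+1.j], dtype=torch.complex32)

# GPU path: per the Summary, the CUDA implementation already works as-is.
if torch.cuda.is_available():
    a_gpu = torch.tensor([1j], dtype=torch.complex32, device="cuda")
    b_gpu = a_gpu.clone()
    print(b_gpu)  # tensor([0.+1.j], device='cuda:0', dtype=torch.complex32)
    print(a_gpu)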
@@ -17,6 +17,8 @@ static void copy_kernel(TensorIterator& iter, bool non_blocking) {
       cpu_kernel(iter, [=](at::Half a) -> at::Half { return a; });
     } else if (dtype == ScalarType::BFloat16) {
       cpu_kernel(iter, [=](at::BFloat16 a) -> at::BFloat16 { return a; });
+    } else if (dtype == ScalarType::ComplexHalf) {
+      cpu_kernel(iter, [=](c10::complex<at::Half> a) -> c10::complex<at::Half> { return a; });
     } else if (isQIntType(dtype)) {
       AT_DISPATCH_QINT_TYPES(dtype, "copy_kernel", [&] {
         cpu_kernel_vec(
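The added branch gives the CPU kernel an elementwise identity copy for c10::complex<at::Half>, which was previously not handled by the dtype dispatch in copy_kernel. A minimal sketch of reaching this path through Tensor.copy_ directly (the variable names are illustrative, and empty_like on complex32 is assumed to work, as the Test Plan's tensor construction suggests):

import torch

src = torch.tensor([1 + 2j, 3 - 4j], dtype=torch.complex32)  # ComplexHalf
dst = torch.empty_like(src)

# A CPU-to-CPU copy between complex32 tensors dispatches to copy_kernel,
# which now handles ScalarType::ComplexHalf explicitly.
dst.copy_(src)
print(dst)  # should match src: tensor([1.+2.j, 3.-4.j], dtype=torch.complex32)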