mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Fix complex.pow(2) on GPU by replacing it with complex * complex to avoid numerical instability (#152373)
Fixes #150951. Summary: For complex.pow(2) on GPU, this change uses complex * complex directly, produces results consistent with the CPU implementation, and eliminates spurious imaginary components for real inputs. Tests: added unit tests to verify correctness of the new kernel path, and verified numerical consistency with CPU results. This change is backward-compatible and only affects the specific case of pow(2) on complex tensors on GPU. Pull Request resolved: https://github.com/pytorch/pytorch/pull/152373 Approved by: https://github.com/ezyang
This commit is contained in:
parent
e290a4c645
commit
382c6190c1
|
|
@ -185,6 +185,12 @@ void pow_tensor_scalar_kernel(TensorIteratorBase& iter, const Scalar& exp_scalar
|
|||
return;
|
||||
}
|
||||
AT_DISPATCH_COMPLEX_TYPES(iter.common_dtype(), "pow_cuda", [&]() {
|
||||
if (exp_scalar.equal(2.0)) {
|
||||
gpu_kernel(iter, [=]GPU_LAMBDA(scalar_t base) -> scalar_t {
|
||||
return base * base;
|
||||
});
|
||||
return;
|
||||
}
|
||||
const auto exp = exp_scalar.to<scalar_t>();
|
||||
gpu_kernel(iter, [=]GPU_LAMBDA(scalar_t base) -> scalar_t {
|
||||
return pow_(base, exp);
|
||||
|
|
|
|||
|
|
@ -1688,12 +1688,11 @@ class TestBinaryUfuncs(TestCase):
|
|||
|
||||
@onlyCUDA
|
||||
@dtypes(torch.complex64, torch.complex128)
|
||||
def test_pow_cuda_complex_extremal_failing(self, device, dtype):
|
||||
def test_pow_cuda_complex_extremal_passing(self, device, dtype):
|
||||
t = torch.tensor(complex(-1.0, float("inf")), dtype=dtype, device=device)
|
||||
with self.assertRaises(AssertionError):
|
||||
cuda_out = t.pow(2)
|
||||
cpu_out = t.cpu().pow(2)
|
||||
self.assertEqual(cpu_out, cuda_out)
|
||||
cuda_out = t.pow(2)
|
||||
cpu_out = t.cpu().pow(2)
|
||||
self.assertEqual(cpu_out, cuda_out)
|
||||
|
||||
@skipIfTorchDynamo()
|
||||
@onlyNativeDeviceTypes
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user