Silu support Complex for CUDA (#106854)

Fixes #89382

Silu support Complex for CUDA
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106854
Approved by: https://github.com/albanD
This commit is contained in:
FFFrog 2023-08-19 06:57:05 +00:00 committed by PyTorch MergeBot
parent 3ddf30505f
commit 871d7d242d
3 changed files with 29 additions and 4 deletions

View File

@ -15,12 +15,13 @@
#include <ATen/cuda/ApplyGridUtils.cuh> #include <ATen/cuda/ApplyGridUtils.cuh>
#include <ATen/cuda/detail/OffsetCalculator.cuh> #include <ATen/cuda/detail/OffsetCalculator.cuh>
#include <ATen/native/cuda/Loops.cuh> #include <ATen/native/cuda/Loops.cuh>
#include <c10/util/complex.h>
namespace at::native { namespace at::native {
namespace { namespace {
void silu_kernel(TensorIteratorBase& iter) { void silu_kernel(TensorIteratorBase& iter) {
AT_DISPATCH_FLOATING_TYPES_AND2( AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(
at::ScalarType::Half, at::ScalarType::Half,
at::ScalarType::BFloat16, at::ScalarType::BFloat16,
iter.dtype(), iter.dtype(),
@ -29,7 +30,7 @@ void silu_kernel(TensorIteratorBase& iter) {
gpu_kernel(iter, [] GPU_LAMBDA(scalar_t x) -> scalar_t { gpu_kernel(iter, [] GPU_LAMBDA(scalar_t x) -> scalar_t {
using opmath_t = at::opmath_type<scalar_t>; using opmath_t = at::opmath_type<scalar_t>;
const opmath_t x_acc = static_cast<opmath_t>(x); const opmath_t x_acc = static_cast<opmath_t>(x);
return x_acc / (opmath_t(1) + c10::cuda::compat::exp(-x_acc)); return x_acc / (opmath_t(1) + ::exp(-x_acc));
}); });
}); });
} }

View File

@ -1054,6 +1054,30 @@ class TestUnaryUfuncs(TestCase):
rtol=rtol, rtol=rtol,
) )
@dtypes(torch.complex64, torch.complex128)
def test_silu_complex(self, device, dtype):
atol = 1e-6
rtol = 1e-6
inouts = [
(0.2 + 0.3j, 0.08775215595960617065 + 0.18024823069572448730j),
(1e-19 + 1e-18j, 4.99999984132761269448e-20 + 5.00000022906852482872e-19j),
(-1.0 + 2.0j, -0.78546208143234252930 + -0.44626939296722412109j),
(0.0 + 0.5j, -0.06383547931909561157 + 0.25000000000000000000j),
(2.0j, -1.55740761756896972656 + 0.99999988079071044922j)
]
for inp, out in inouts:
res = torch.nn.functional.silu(torch.tensor(inp, dtype=dtype, device=device))
self.assertFalse(torch.any(torch.isnan(res)))
self.assertEqual(res.real, out.real, atol=atol, rtol=rtol)
self.assertEqual(res.imag, out.imag, atol=atol, rtol=rtol)
for inp, out in inouts:
res = torch.nn.functional.silu(torch.tensor(inp, dtype=dtype, device=device), inplace=True)
self.assertFalse(torch.any(torch.isnan(res)))
self.assertEqual(res.real, out.real, atol=atol, rtol=rtol)
self.assertEqual(res.imag, out.imag, atol=atol, rtol=rtol)
# It is not obvious how to merge this into OpInfo becuase these inputs # It is not obvious how to merge this into OpInfo becuase these inputs
# succeed for gradcheck but are expected to fail for gradgradcheck # succeed for gradcheck but are expected to fail for gradgradcheck
@dtypes(torch.double) @dtypes(torch.double)

View File

@ -13528,7 +13528,7 @@ op_db: List[OpInfo] = [
ref=lambda x, inplace=False: ref=lambda x, inplace=False:
x / (1 + np.exp(-x)), x / (1 + np.exp(-x)),
dtypes=complex_types(), dtypes=complex_types(),
dtypesIfCUDA=empty_types(), dtypesIfCUDA=complex_types(),
supports_forward_ad=False, supports_forward_ad=False,
supports_autograd=False, supports_autograd=False,
assert_autodiffed=False, assert_autodiffed=False,
@ -13544,7 +13544,7 @@ op_db: List[OpInfo] = [
), ], ), ],
skips=( skips=(
DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_normal', DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_normal',
dtypes=(torch.cfloat,), device_type='cpu'), dtypes=(torch.cfloat,)),
# FIXME: intentionally misreports dtypes # FIXME: intentionally misreports dtypes
DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_dtypes'), DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_dtypes'),
# FIXME: numpy reference diverges: Comparing (nan+nanj) and (-0+0j) # FIXME: numpy reference diverges: Comparing (nan+nanj) and (-0+0j)