diff --git a/aten/src/ATen/native/cuda/ActivationSiluKernel.cu b/aten/src/ATen/native/cuda/ActivationSiluKernel.cu
index e9b495ca70b..82096b96dbb 100644
--- a/aten/src/ATen/native/cuda/ActivationSiluKernel.cu
+++ b/aten/src/ATen/native/cuda/ActivationSiluKernel.cu
@@ -15,12 +15,13 @@
 #include
 #include
 #include
+#include

 namespace at::native {
 namespace {

 void silu_kernel(TensorIteratorBase& iter) {
-  AT_DISPATCH_FLOATING_TYPES_AND2(
+  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(
       at::ScalarType::Half,
       at::ScalarType::BFloat16,
       iter.dtype(),
@@ -29,7 +30,7 @@ void silu_kernel(TensorIteratorBase& iter) {
         gpu_kernel(iter, [] GPU_LAMBDA(scalar_t x) -> scalar_t {
           using opmath_t = at::opmath_type<scalar_t>;
           const opmath_t x_acc = static_cast<opmath_t>(x);
-          return x_acc / (opmath_t(1) + c10::cuda::compat::exp(-x_acc));
+          return x_acc / (opmath_t(1) + ::exp(-x_acc));
         });
       });
 }
diff --git a/test/test_unary_ufuncs.py b/test/test_unary_ufuncs.py
index 76a2c9ebf5f..6c6744b6b78 100644
--- a/test/test_unary_ufuncs.py
+++ b/test/test_unary_ufuncs.py
@@ -1054,6 +1054,30 @@ class TestUnaryUfuncs(TestCase):
             rtol=rtol,
         )

+    @dtypes(torch.complex64, torch.complex128)
+    def test_silu_complex(self, device, dtype):
+        atol = 1e-6
+        rtol = 1e-6
+        inouts = [
+            (0.2 + 0.3j, 0.08775215595960617065 + 0.18024823069572448730j),
+            (1e-19 + 1e-18j, 4.99999984132761269448e-20 + 5.00000022906852482872e-19j),
+            (-1.0 + 2.0j, -0.78546208143234252930 + -0.44626939296722412109j),
+            (0.0 + 0.5j, -0.06383547931909561157 + 0.25000000000000000000j),
+            (2.0j, -1.55740761756896972656 + 0.99999988079071044922j)
+        ]
+
+        for inp, out in inouts:
+            res = torch.nn.functional.silu(torch.tensor(inp, dtype=dtype, device=device))
+            self.assertFalse(torch.any(torch.isnan(res)))
+            self.assertEqual(res.real, out.real, atol=atol, rtol=rtol)
+            self.assertEqual(res.imag, out.imag, atol=atol, rtol=rtol)
+
+        for inp, out in inouts:
+            res = torch.nn.functional.silu(torch.tensor(inp, dtype=dtype, device=device), inplace=True)
+            self.assertFalse(torch.any(torch.isnan(res)))
+            self.assertEqual(res.real, out.real, atol=atol, rtol=rtol)
+            self.assertEqual(res.imag, out.imag, atol=atol, rtol=rtol)
+
     # It is not obvious how to merge this into OpInfo becuase these inputs
     # succeed for gradcheck but are expected to fail for gradgradcheck
     @dtypes(torch.double)
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index 735f6c8ad29..977536e5876 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -13528,7 +13528,7 @@ op_db: List[OpInfo] = [
         ref=lambda x, inplace=False: x / (1 + np.exp(-x)),
         dtypes=complex_types(),
-        dtypesIfCUDA=empty_types(),
+        dtypesIfCUDA=complex_types(),
         supports_forward_ad=False,
         supports_autograd=False,
         assert_autodiffed=False,
@@ -13544,7 +13544,7 @@
            ),
        ],
        skips=(
            DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_normal',
-                        dtypes=(torch.cfloat,), device_type='cpu'),
+                        dtypes=(torch.cfloat,)),
            # FIXME: intentionally misreports dtypes
            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_dtypes'),
            # FIXME: numpy reference diverges: Comparing (nan+nanj) and (-0+0j)
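
Side note (not part of the patch): the hard-coded expectations in test_silu_complex can be cross-checked against the OpInfo reference shown above, ref=lambda x, inplace=False: x / (1 + np.exp(-x)). A minimal sketch follows, assuming NumPy is available; silu_ref and cases are illustrative names introduced here, not identifiers from the patch.

import numpy as np

def silu_ref(x):
    # Same formula as the OpInfo reference: silu(x) = x * sigmoid(x) = x / (1 + exp(-x)).
    return x / (1 + np.exp(-x))

# A few (input, expected) pairs taken from test_silu_complex; the expected values
# carry float32-level precision, hence the loose 1e-6 tolerances used by the test.
cases = [
    (0.2 + 0.3j, 0.08775215595960617065 + 0.18024823069572448730j),
    (-1.0 + 2.0j, -0.78546208143234252930 - 0.44626939296722412109j),
    (0.0 + 0.5j, -0.06383547931909561157 + 0.25j),
    (2.0j, -1.55740761756896972656 + 0.99999988079071044922j),
]
for inp, expected in cases:
    np.testing.assert_allclose(silu_ref(inp), expected, atol=1e-6, rtol=1e-6)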