Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-06 12:20:52 +01:00

Silu support Complex for CUDA (#106854)

Fixes #89382

Pull Request resolved: https://github.com/pytorch/pytorch/pull/106854
Approved by: https://github.com/albanD

This commit is contained in:
parent 3ddf30505f
commit 871d7d242d
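
A minimal usage sketch of what this change enables, assuming a CUDA-capable build that already contains the commit; the closed form x / (1 + exp(-x)) is the same one used as the OpInfo reference further down in this diff:

import torch

# Sketch only, not part of the commit: after this change, SiLU accepts complex
# tensors on CUDA and matches the closed form x / (1 + exp(-x)).
if torch.cuda.is_available():
    x = torch.tensor([0.2 + 0.3j, -1.0 + 2.0j], dtype=torch.complex64, device="cuda")
    y = torch.nn.functional.silu(x)
    reference = x / (1 + torch.exp(-x))
    print(torch.allclose(y, reference, atol=1e-6, rtol=1e-6))  # expected: True
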
@@ -15,12 +15,13 @@
 #include <ATen/cuda/ApplyGridUtils.cuh>
 #include <ATen/cuda/detail/OffsetCalculator.cuh>
 #include <ATen/native/cuda/Loops.cuh>
+#include <c10/util/complex.h>
 
 namespace at::native {
 namespace {
 
 void silu_kernel(TensorIteratorBase& iter) {
-  AT_DISPATCH_FLOATING_TYPES_AND2(
+  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(
       at::ScalarType::Half,
       at::ScalarType::BFloat16,
       iter.dtype(),

@@ -29,7 +30,7 @@ void silu_kernel(TensorIteratorBase& iter) {
         gpu_kernel(iter, [] GPU_LAMBDA(scalar_t x) -> scalar_t {
           using opmath_t = at::opmath_type<scalar_t>;
           const opmath_t x_acc = static_cast<opmath_t>(x);
-          return x_acc / (opmath_t(1) + c10::cuda::compat::exp(-x_acc));
+          return x_acc / (opmath_t(1) + ::exp(-x_acc));
         });
       });
 }
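
With the dispatch macro switched to its complex-aware variant, silu on CUDA now reaches complex64/complex128 in addition to the floating-point types plus Half and BFloat16. A rough sketch probing that coverage from Python, assuming a CUDA device is available:

import torch

# Sketch only: probe which dtypes dispatch into the CUDA silu kernel after
# this change (floating types and Half/BFloat16 as before, plus complex).
if torch.cuda.is_available():
    dtypes = (torch.float16, torch.bfloat16, torch.float32, torch.float64,
              torch.complex64, torch.complex128)
    for dtype in dtypes:
        x = torch.ones(4, dtype=dtype, device="cuda")
        try:
            torch.nn.functional.silu(x)
            print(dtype, "supported")
        except RuntimeError as err:
            print(dtype, "not supported:", err)
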
@@ -1054,6 +1054,30 @@ class TestUnaryUfuncs(TestCase):
                 rtol=rtol,
             )
 
+    @dtypes(torch.complex64, torch.complex128)
+    def test_silu_complex(self, device, dtype):
+        atol = 1e-6
+        rtol = 1e-6
+        inouts = [
+            (0.2 + 0.3j, 0.08775215595960617065 + 0.18024823069572448730j),
+            (1e-19 + 1e-18j, 4.99999984132761269448e-20 + 5.00000022906852482872e-19j),
+            (-1.0 + 2.0j, -0.78546208143234252930 + -0.44626939296722412109j),
+            (0.0 + 0.5j, -0.06383547931909561157 + 0.25000000000000000000j),
+            (2.0j, -1.55740761756896972656 + 0.99999988079071044922j)
+        ]
+
+        for inp, out in inouts:
+            res = torch.nn.functional.silu(torch.tensor(inp, dtype=dtype, device=device))
+            self.assertFalse(torch.any(torch.isnan(res)))
+            self.assertEqual(res.real, out.real, atol=atol, rtol=rtol)
+            self.assertEqual(res.imag, out.imag, atol=atol, rtol=rtol)
+
+        for inp, out in inouts:
+            res = torch.nn.functional.silu(torch.tensor(inp, dtype=dtype, device=device), inplace=True)
+            self.assertFalse(torch.any(torch.isnan(res)))
+            self.assertEqual(res.real, out.real, atol=atol, rtol=rtol)
+            self.assertEqual(res.imag, out.imag, atol=atol, rtol=rtol)
+
     # It is not obvious how to merge this into OpInfo becuase these inputs
     # succeed for gradcheck but are expected to fail for gradgradcheck
     @dtypes(torch.double)
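
The hard-coded expectations above follow from the same closed form the kernel applies, silu(x) = x / (1 + exp(-x)); the constants appear to be single-precision values printed with extra digits, consistent with the shared atol = rtol = 1e-6 used for both complex64 and complex128. A small host-side cross-check, assuming only NumPy:

import numpy as np

# Sketch only: reproduce the (input, expected) pairs from test_silu_complex
# with silu(x) = x / (1 + exp(-x)) evaluated in double precision.
inouts = [
    (0.2 + 0.3j, 0.08775215595960617065 + 0.18024823069572448730j),
    (1e-19 + 1e-18j, 4.99999984132761269448e-20 + 5.00000022906852482872e-19j),
    (-1.0 + 2.0j, -0.78546208143234252930 + -0.44626939296722412109j),
    (0.0 + 0.5j, -0.06383547931909561157 + 0.25000000000000000000j),
    (2.0j, -1.55740761756896972656 + 0.99999988079071044922j),
]
for inp, expected in inouts:
    got = inp / (1 + np.exp(-np.complex128(inp)))
    assert np.isclose(got, expected, rtol=1e-6, atol=1e-6), (inp, got, expected)
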
@@ -13528,7 +13528,7 @@ op_db: List[OpInfo] = [
         ref=lambda x, inplace=False:
                 x / (1 + np.exp(-x)),
         dtypes=complex_types(),
-        dtypesIfCUDA=empty_types(),
+        dtypesIfCUDA=complex_types(),
         supports_forward_ad=False,
         supports_autograd=False,
         assert_autodiffed=False,

@@ -13544,7 +13544,7 @@ op_db: List[OpInfo] = [
         ), ],
         skips=(
             DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_normal',
-                         dtypes=(torch.cfloat,), device_type='cpu'),
+                         dtypes=(torch.cfloat,)),
             # FIXME: intentionally misreports dtypes
             DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_dtypes'),
             # FIXME: numpy reference diverges: Comparing (nan+nanj) and (-0+0j)
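
The last FIXME notes that the NumPy reference x / (1 + np.exp(-x)) can produce NaNs where the operator itself returns a finite value. One way such a divergence can arise (a hand-picked input for illustration, not taken from the test suite) is overflow of exp(-x) for inputs with a large negative real part:

import numpy as np

# Sketch only: with a large negative real part, exp(-x) overflows in complex64,
# so the naive reference x / (1 + exp(-x)) degenerates to nan+nanj even though
# the exact silu value is tiny (roughly x * exp(x), on the order of 1e-42 here).
x = np.complex64(-100.0 + 1.0j)
with np.errstate(over="ignore", invalid="ignore"):
    print(x / (1 + np.exp(-x)))  # (nan+nanj)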