Silu support Complex for CUDA (#106854)
Fixes #89382

Pull Request resolved: https://github.com/pytorch/pytorch/pull/106854
Approved by: https://github.com/albanD

parent 3ddf30505f
commit 871d7d242d
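With this change, torch.nn.functional.silu accepts complex tensors on CUDA as well as on CPU. A minimal usage sketch (assuming a CUDA device and a build that contains this commit; the input values are only illustrative):

    import torch
    import torch.nn.functional as F

    # Complex SiLU on CUDA, the capability added by this commit.
    x = torch.tensor([0.2 + 0.3j, -1.0 + 2.0j], dtype=torch.complex64, device="cuda")
    y = F.silu(x)              # out-of-place
    F.silu(x, inplace=True)    # in-place variant, also exercised by the new test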
@@ -15,12 +15,13 @@
 #include <ATen/cuda/ApplyGridUtils.cuh>
 #include <ATen/cuda/detail/OffsetCalculator.cuh>
 #include <ATen/native/cuda/Loops.cuh>
+#include <c10/util/complex.h>

 namespace at::native {
 namespace {

 void silu_kernel(TensorIteratorBase& iter) {
-  AT_DISPATCH_FLOATING_TYPES_AND2(
+  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(
       at::ScalarType::Half,
       at::ScalarType::BFloat16,
       iter.dtype(),
@@ -29,7 +30,7 @@ void silu_kernel(TensorIteratorBase& iter) {
         gpu_kernel(iter, [] GPU_LAMBDA(scalar_t x) -> scalar_t {
           using opmath_t = at::opmath_type<scalar_t>;
           const opmath_t x_acc = static_cast<opmath_t>(x);
-          return x_acc / (opmath_t(1) + c10::cuda::compat::exp(-x_acc));
+          return x_acc / (opmath_t(1) + ::exp(-x_acc));
         });
       });
 }
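For complex inputs, the lambda above evaluates the same closed form as before, silu(x) = x * sigmoid(x) = x / (1 + exp(-x)), with the complex-capable exp brought in through the new c10/util/complex.h include. A purely illustrative sketch of that arithmetic in Python's cmath (not part of the change):

    import cmath

    # The closed form the CUDA lambda computes elementwise.
    def silu(x: complex) -> complex:
        return x / (1 + cmath.exp(-x))

    print(silu(0.2 + 0.3j))  # ~(0.0877521560 + 0.1802482307j), cf. the new test below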
@@ -1054,6 +1054,30 @@ class TestUnaryUfuncs(TestCase):
                 rtol=rtol,
             )

+    @dtypes(torch.complex64, torch.complex128)
+    def test_silu_complex(self, device, dtype):
+        atol = 1e-6
+        rtol = 1e-6
+        inouts = [
+            (0.2 + 0.3j, 0.08775215595960617065 + 0.18024823069572448730j),
+            (1e-19 + 1e-18j, 4.99999984132761269448e-20 + 5.00000022906852482872e-19j),
+            (-1.0 + 2.0j, -0.78546208143234252930 + -0.44626939296722412109j),
+            (0.0 + 0.5j, -0.06383547931909561157 + 0.25000000000000000000j),
+            (2.0j, -1.55740761756896972656 + 0.99999988079071044922j)
+        ]
+
+        for inp, out in inouts:
+            res = torch.nn.functional.silu(torch.tensor(inp, dtype=dtype, device=device))
+            self.assertFalse(torch.any(torch.isnan(res)))
+            self.assertEqual(res.real, out.real, atol=atol, rtol=rtol)
+            self.assertEqual(res.imag, out.imag, atol=atol, rtol=rtol)
+
+        for inp, out in inouts:
+            res = torch.nn.functional.silu(torch.tensor(inp, dtype=dtype, device=device), inplace=True)
+            self.assertFalse(torch.any(torch.isnan(res)))
+            self.assertEqual(res.real, out.real, atol=atol, rtol=rtol)
+            self.assertEqual(res.imag, out.imag, atol=atol, rtol=rtol)
+
     # It is not obvious how to merge this into OpInfo becuase these inputs
     # succeed for gradcheck but are expected to fail for gradgradcheck
     @dtypes(torch.double)
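The expected constants above can be cross-checked against the closed form directly in PyTorch. A hedged sketch on CPU, where complex SiLU already worked before this change; with device="cuda" on a build containing this commit, the same check exercises the new kernel:

    import torch

    inp = torch.tensor([0.2 + 0.3j, -1.0 + 2.0j, 2.0j], dtype=torch.complex128)
    ref = inp / (1 + torch.exp(-inp))          # closed form used by the kernel
    out = torch.nn.functional.silu(inp)
    torch.testing.assert_close(out, ref, atol=1e-6, rtol=1e-6)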
@@ -13528,7 +13528,7 @@ op_db: List[OpInfo] = [
         ref=lambda x, inplace=False:
             x / (1 + np.exp(-x)),
         dtypes=complex_types(),
-        dtypesIfCUDA=empty_types(),
+        dtypesIfCUDA=complex_types(),
         supports_forward_ad=False,
         supports_autograd=False,
         assert_autodiffed=False,
@@ -13544,7 +13544,7 @@ op_db: List[OpInfo] = [
             ), ],
         skips=(
             DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_normal',
-                         dtypes=(torch.cfloat,), device_type='cpu'),
+                         dtypes=(torch.cfloat,)),
             # FIXME: intentionally misreports dtypes
             DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_dtypes'),
             # FIXME: numpy reference diverges: Comparing (nan+nanj) and (-0+0j)
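Switching dtypesIfCUDA from empty_types() to complex_types() means the OpInfo-generated tests now also cover the CUDA complex path against the NumPy reference, which is already complex-capable. A small illustrative sketch of that reference (not part of the change):

    import numpy as np

    # Same closed form as the OpInfo ref lambda above.
    x = np.array([0.2 + 0.3j, 2.0j], dtype=np.complex128)
    print(x / (1 + np.exp(-x)))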