disable jiterator for complex tan and tanh (#165250)
Fixes #100842

Disable the jiterator for the complex tan and tanh kernels due to accuracy issues, matching the existing approach used for acos, acosh, asin, and asinh. This reverts to the thrust implementation, which provides better numerical accuracy.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165250
Approved by: https://github.com/ezyang
parent cde81e92b9
commit 17d5aa4767
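For context, here is a minimal standalone sketch of the kind of accuracy check that motivated this change: evaluating complex tan on CUDA near the real-axis poles and comparing against a NumPy reference. The sample points and tolerances below are illustrative assumptions, not values taken from the commit.

import math

import numpy as np
import torch

# Points close to the poles of tan on the real axis, offset into the
# complex plane; this is where implementations tend to diverge.
real = torch.tensor([math.pi / 2 - 1e-3, math.pi / 2 + 1e-3], dtype=torch.float32)
imag = torch.tensor([0.5, -0.5], dtype=torch.float32)
z = torch.complex(real, imag).cuda()  # complex64

# Compare the CUDA result against NumPy's complex tan.
expected = np.tan(z.cpu().numpy())
actual = torch.tan(z).cpu().numpy()
np.testing.assert_allclose(actual, expected, rtol=1e-4, atol=1e-5)  # illustrative tolerances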
Complex tan CUDA kernel:

@@ -12,14 +12,15 @@
 namespace at::native {

-#if AT_USE_JITERATOR()
+#if 0 && AT_USE_JITERATOR()
 constexpr char tan_name[] = "tan_impl";
 #endif

 void tan_kernel_cuda(TensorIteratorBase& iter) {
   auto common_dtype = iter.common_dtype();
   if (at::isComplexType(common_dtype)) {
-#if AT_USE_JITERATOR()
+    // Disabled due to accuracy issues
+#if 0 && AT_USE_JITERATOR()
     static const auto tan_string = jiterator_stringify(
         template <typename T> T tan_impl(T a) { return std::tan(a); });
     AT_DISPATCH_COMPLEX_TYPES_AND(
Complex tanh CUDA kernel:

@@ -12,14 +12,15 @@
 namespace at::native {

-#if AT_USE_JITERATOR()
+#if 0 && AT_USE_JITERATOR()
 constexpr char tanh_name[] = "tanh_impl";
 #endif

 void tanh_kernel_cuda(TensorIteratorBase& iter) {
   auto common_dtype = iter.common_dtype();
   if (at::isComplexType(common_dtype)) {
-#if AT_USE_JITERATOR()
+    // Disabled due to accuracy issues
+#if 0 && AT_USE_JITERATOR()
     static const auto tanh_string = jiterator_stringify(
         template <typename T> T tanh_impl(T a) { return std::tanh(a); });
     AT_DISPATCH_COMPLEX_TYPES_AND(
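For readers unfamiliar with the jiterator: it JIT-compiles elementwise CUDA kernels from C++ source strings at runtime, which is what jiterator_stringify feeds it in the kernels above. A minimal sketch of the user-facing Python counterpart, torch.cuda.jiterator._create_jit_fn, follows; this is illustrative only, is not part of this commit, and assumes this private API accepts complex64 inputs.

import torch

# The same kind of C++ source string that jiterator_stringify embeds above.
code_string = "template <typename T> T tan_impl(T a) { return std::tan(a); }"
jitted_tan = torch.cuda.jiterator._create_jit_fn(code_string)

x = torch.randn(8, dtype=torch.complex64, device="cuda")
y = jitted_tan(x)  # elementwise tan; the kernel is compiled on first call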
Unary ufunc tests:

@@ -773,6 +773,48 @@ class TestUnaryUfuncs(TestCase):
         with self.assertRaises(AttributeError):
             torch_inplace_method = getattr(torch.Tensor, fn_name + "_")

+    @onlyCUDA
+    @dtypes(torch.complex64)
+    def test_tan_complex_cuda_matches_numpy(self, device, dtype):
+        # Focused accuracy check for complex tan on CUDA against NumPy reference
+        # Includes values near tan singularities on the real axis
+        eps = 1e-3
+        specials = torch.tensor(
+            [
+                math.pi / 2 - eps,
+                math.pi / 2 + eps,
+                -math.pi / 2 - eps,
+                -math.pi / 2 + eps,
+            ],
+            device=device,
+            dtype=torch.float32,
+        )
+        real = torch.randn(1024, device=device, dtype=torch.float32) * (2 * math.pi)
+        imag = torch.randn(1024, device=device, dtype=torch.float32) * 5.0
+        real = torch.cat([real, specials])
+        imag = torch.cat(
+            [
+                imag,
+                torch.linspace(
+                    -3,
+                    3,
+                    steps=specials.numel(),
+                    device=device,
+                ),
+            ]
+        )
+        z = torch.complex(real, imag).to(dtype)
+        self.compare_with_numpy(torch.tan, np.tan, z)
+
+    @onlyCUDA
+    @dtypes(torch.complex64)
+    def test_tanh_complex_cuda_matches_numpy(self, device, dtype):
+        # Focused accuracy check for complex tanh on CUDA against NumPy reference
+        real = torch.randn(2048, device=device, dtype=torch.float32) * (2 * math.pi)
+        imag = torch.randn(2048, device=device, dtype=torch.float32) * 5.0
+        z = torch.complex(real, imag).to(dtype)
+        self.compare_with_numpy(torch.tanh, np.tanh, z)
+
     def check_internal_mem_overlap(
         self, inplace_op, num_inputs, dtype, device, expected_failure=False
     ):
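On a machine with a CUDA device, the new tests can be run directly, for example with pytest test/test_unary_ufuncs.py -k "complex_cuda_matches_numpy" (the file path is an assumption based on the TestUnaryUfuncs class and the usual PyTorch test layout).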