From 17d5aa47674c5195a3d4919169fb6a334598e2a6 Mon Sep 17 00:00:00 2001
From: Amin Sedaghat
Date: Wed, 29 Oct 2025 04:58:58 +0000
Subject: [PATCH] disable jiterator for complex tan and tanh (#165250)

Fixes #100842

Disable jiterator for complex tan and tanh kernels due to accuracy
issues, matching the existing approach used for acos, acosh, asin, and
asinh. Reverts to the thrust implementation, which provides better
numerical accuracy.
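As a quick illustration of the comparison the new tests perform, the sketch
below checks the CUDA complex64 tanh result against NumPy's reference on
random inputs. It is only an illustrative sketch: it assumes a CUDA-enabled
build, and the assert_allclose tolerances are placeholders, not the values
used by the test suite's compare_with_numpy helper.

    # Illustrative spot-check only; assumes a CUDA-enabled PyTorch build.
    # Tolerances are placeholders, not the values used by the test suite.
    import math

    import numpy as np
    import torch

    real = torch.randn(2048, dtype=torch.float32) * (2 * math.pi)
    imag = torch.randn(2048, dtype=torch.float32) * 5.0
    z = torch.complex(real, imag)             # complex64, constructed on CPU

    got = torch.tanh(z.cuda()).cpu().numpy()  # exercises the CUDA kernel
    expected = np.tanh(z.numpy())             # NumPy reference on the same inputs

    np.testing.assert_allclose(got, expected, rtol=1e-5, atol=1e-5)

The analogous check for tan additionally samples points near the real-axis
singularities at +/- pi/2, as test_tan_complex_cuda_matches_numpy does below.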
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165250
Approved by: https://github.com/ezyang
---
 .../native/cuda/UnaryGeometricTanKernel.cu    |  5 ++-
 .../native/cuda/UnaryGeometricTanhKernel.cu   |  5 ++-
 test/test_unary_ufuncs.py                     | 42 +++++++++++++++++++
 3 files changed, 48 insertions(+), 4 deletions(-)

diff --git a/aten/src/ATen/native/cuda/UnaryGeometricTanKernel.cu b/aten/src/ATen/native/cuda/UnaryGeometricTanKernel.cu
index 34e055d589a..bedb5add839 100644
--- a/aten/src/ATen/native/cuda/UnaryGeometricTanKernel.cu
+++ b/aten/src/ATen/native/cuda/UnaryGeometricTanKernel.cu
@@ -12,14 +12,15 @@
 
 namespace at::native {
 
-#if AT_USE_JITERATOR()
+#if 0 && AT_USE_JITERATOR()
 constexpr char tan_name[] = "tan_impl";
 #endif
 
 void tan_kernel_cuda(TensorIteratorBase& iter) {
   auto common_dtype = iter.common_dtype();
   if (at::isComplexType(common_dtype)) {
-#if AT_USE_JITERATOR()
+    // Disabled due to accuracy issues
+#if 0 && AT_USE_JITERATOR()
     static const auto tan_string = jiterator_stringify(
         template <typename T> T tan_impl(T a) { return std::tan(a); });
     AT_DISPATCH_COMPLEX_TYPES_AND(
diff --git a/aten/src/ATen/native/cuda/UnaryGeometricTanhKernel.cu b/aten/src/ATen/native/cuda/UnaryGeometricTanhKernel.cu
index 61393eec8ca..dedb15473fc 100644
--- a/aten/src/ATen/native/cuda/UnaryGeometricTanhKernel.cu
+++ b/aten/src/ATen/native/cuda/UnaryGeometricTanhKernel.cu
@@ -12,14 +12,15 @@
 
 namespace at::native {
 
-#if AT_USE_JITERATOR()
+#if 0 && AT_USE_JITERATOR()
 constexpr char tanh_name[] = "tanh_impl";
 #endif
 
 void tanh_kernel_cuda(TensorIteratorBase& iter) {
   auto common_dtype = iter.common_dtype();
   if (at::isComplexType(common_dtype)) {
-#if AT_USE_JITERATOR()
+    // Disabled due to accuracy issues
+#if 0 && AT_USE_JITERATOR()
     static const auto tanh_string = jiterator_stringify(
         template <typename T> T tanh_impl(T a) { return std::tanh(a); });
     AT_DISPATCH_COMPLEX_TYPES_AND(
diff --git a/test/test_unary_ufuncs.py b/test/test_unary_ufuncs.py
index 15b967e5707..13f205c4d11 100644
--- a/test/test_unary_ufuncs.py
+++ b/test/test_unary_ufuncs.py
@@ -773,6 +773,48 @@ class TestUnaryUfuncs(TestCase):
         with self.assertRaises(AttributeError):
             torch_inplace_method = getattr(torch.Tensor, fn_name + "_")
 
+    @onlyCUDA
+    @dtypes(torch.complex64)
+    def test_tan_complex_cuda_matches_numpy(self, device, dtype):
+        # Focused accuracy check for complex tan on CUDA against NumPy reference
+        # Includes values near tan singularities on the real axis
+        eps = 1e-3
+        specials = torch.tensor(
+            [
+                math.pi / 2 - eps,
+                math.pi / 2 + eps,
+                -math.pi / 2 - eps,
+                -math.pi / 2 + eps,
+            ],
+            device=device,
+            dtype=torch.float32,
+        )
+        real = torch.randn(1024, device=device, dtype=torch.float32) * (2 * math.pi)
+        imag = torch.randn(1024, device=device, dtype=torch.float32) * 5.0
+        real = torch.cat([real, specials])
+        imag = torch.cat(
+            [
+                imag,
+                torch.linspace(
+                    -3,
+                    3,
+                    steps=specials.numel(),
+                    device=device,
+                ),
+            ]
+        )
+        z = torch.complex(real, imag).to(dtype)
+        self.compare_with_numpy(torch.tan, np.tan, z)
+
+    @onlyCUDA
+    @dtypes(torch.complex64)
+    def test_tanh_complex_cuda_matches_numpy(self, device, dtype):
+        # Focused accuracy check for complex tanh on CUDA against NumPy reference
+        real = torch.randn(2048, device=device, dtype=torch.float32) * (2 * math.pi)
+        imag = torch.randn(2048, device=device, dtype=torch.float32) * 5.0
+        z = torch.complex(real, imag).to(dtype)
+        self.compare_with_numpy(torch.tanh, np.tanh, z)
+
     def check_internal_mem_overlap(
         self, inplace_op, num_inputs, dtype, device, expected_failure=False
     ):