disable jiterator for complex tan and tanh (#165250)

Fixes #100842

Disable the jiterator paths for the complex tan and tanh CUDA kernels due to accuracy issues, matching the existing approach used for acos, acosh, asin, and asinh. This reverts to the thrust implementation, which provides better numerical accuracy. The preprocessor guards are changed from #if AT_USE_JITERATOR() to #if 0 && AT_USE_JITERATOR(), which always evaluates to false, so the jiterator code stays in the tree and can be re-enabled later by dropping the "0 &&".
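As an illustration of the accuracy concern, here is a minimal standalone sketch (not part of the PR; it assumes a CUDA build of PyTorch and mirrors the new test added below) that compares complex tan on CUDA against the NumPy reference just off the real-axis singularity at pi/2:

import math

import numpy as np
import torch

# Probe points just off tan's pole at pi/2 on the real axis,
# with modest imaginary parts, as in the new test.
eps = 1e-3
z = torch.complex(
    torch.tensor([math.pi / 2 - eps, math.pi / 2 + eps]),
    torch.tensor([0.5, -0.5]),
).to(torch.complex64)

got = torch.tan(z.cuda()).cpu().numpy()
expected = np.tan(z.numpy())
print(np.abs(got - expected))  # should be small with the thrust path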

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165250
Approved by: https://github.com/ezyang
Authored by Amin Sedaghat on 2025-10-29 04:58:58 +00:00; committed by PyTorch MergeBot
commit 17d5aa4767 (parent cde81e92b9)
3 changed files with 48 additions and 4 deletions

File 1 of 3: complex tan CUDA kernel

@@ -12,14 +12,15 @@
 namespace at::native {
-#if AT_USE_JITERATOR()
+#if 0 && AT_USE_JITERATOR()
 constexpr char tan_name[] = "tan_impl";
 #endif
 void tan_kernel_cuda(TensorIteratorBase& iter) {
   auto common_dtype = iter.common_dtype();
   if (at::isComplexType(common_dtype)) {
-#if AT_USE_JITERATOR()
+    // Disabled due to accuracy issues
+#if 0 && AT_USE_JITERATOR()
     static const auto tan_string = jiterator_stringify(
         template <typename T> T tan_impl(T a) { return std::tan(a); });
     AT_DISPATCH_COMPLEX_TYPES_AND(

File 2 of 3: complex tanh CUDA kernel

@@ -12,14 +12,15 @@
 namespace at::native {
-#if AT_USE_JITERATOR()
+#if 0 && AT_USE_JITERATOR()
 constexpr char tanh_name[] = "tanh_impl";
 #endif
 void tanh_kernel_cuda(TensorIteratorBase& iter) {
   auto common_dtype = iter.common_dtype();
   if (at::isComplexType(common_dtype)) {
-#if AT_USE_JITERATOR()
+    // Disabled due to accuracy issues
+#if 0 && AT_USE_JITERATOR()
     static const auto tanh_string = jiterator_stringify(
         template <typename T> T tanh_impl(T a) { return std::tanh(a); });
     AT_DISPATCH_COMPLEX_TYPES_AND(
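For tanh, a similarly minimal sketch (again an illustration assuming a CUDA build, mirroring the randomized inputs used by the new test below) checks the CUDA result against NumPy over a broad range of complex inputs:

import math

import numpy as np
import torch

# Random complex inputs spanning several periods on the real axis
# and a wide band of imaginary parts, as in the new test.
torch.manual_seed(0)
real = torch.randn(2048) * (2 * math.pi)
imag = torch.randn(2048) * 5.0
z = torch.complex(real, imag)  # complex64

got = torch.tanh(z.cuda()).cpu().numpy()
expected = np.tanh(z.numpy())
print(np.max(np.abs(got - expected)))  # max deviation from the reference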

File 3 of 3: unary ufunc tests

@@ -773,6 +773,48 @@ class TestUnaryUfuncs(TestCase):
         with self.assertRaises(AttributeError):
             torch_inplace_method = getattr(torch.Tensor, fn_name + "_")
 
+    @onlyCUDA
+    @dtypes(torch.complex64)
+    def test_tan_complex_cuda_matches_numpy(self, device, dtype):
+        # Focused accuracy check for complex tan on CUDA against NumPy reference
+        # Includes values near tan singularities on the real axis
+        eps = 1e-3
+        specials = torch.tensor(
+            [
+                math.pi / 2 - eps,
+                math.pi / 2 + eps,
+                -math.pi / 2 - eps,
+                -math.pi / 2 + eps,
+            ],
+            device=device,
+            dtype=torch.float32,
+        )
+        real = torch.randn(1024, device=device, dtype=torch.float32) * (2 * math.pi)
+        imag = torch.randn(1024, device=device, dtype=torch.float32) * 5.0
+        real = torch.cat([real, specials])
+        imag = torch.cat(
+            [
+                imag,
+                torch.linspace(
+                    -3,
+                    3,
+                    steps=specials.numel(),
+                    device=device,
+                ),
+            ]
+        )
+        z = torch.complex(real, imag).to(dtype)
+        self.compare_with_numpy(torch.tan, np.tan, z)
+
+    @onlyCUDA
+    @dtypes(torch.complex64)
+    def test_tanh_complex_cuda_matches_numpy(self, device, dtype):
+        # Focused accuracy check for complex tanh on CUDA against NumPy reference
+        real = torch.randn(2048, device=device, dtype=torch.float32) * (2 * math.pi)
+        imag = torch.randn(2048, device=device, dtype=torch.float32) * 5.0
+        z = torch.complex(real, imag).to(dtype)
+        self.compare_with_numpy(torch.tanh, np.tanh, z)
+
     def check_internal_mem_overlap(
         self, inplace_op, num_inputs, dtype, device, expected_failure=False
     ):
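To run just the new tests locally, something like the following should work (a hypothetical invocation: it assumes the usual PyTorch layout where TestUnaryUfuncs lives in test/test_unary_ufuncs.py and device-specific variants are generated with a _cuda suffix):

python test/test_unary_ufuncs.py -k tan_complex_cuda_matches_numpy
python test/test_unary_ufuncs.py -k tanh_complex_cuda_matches_numpy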