diff --git a/aten/src/ATen/native/cuda/UnaryGeometricTanKernel.cu b/aten/src/ATen/native/cuda/UnaryGeometricTanKernel.cu index 34e055d589a..bedb5add839 100644 --- a/aten/src/ATen/native/cuda/UnaryGeometricTanKernel.cu +++ b/aten/src/ATen/native/cuda/UnaryGeometricTanKernel.cu @@ -12,14 +12,15 @@ namespace at::native { -#if AT_USE_JITERATOR() +#if 0 && AT_USE_JITERATOR() constexpr char tan_name[] = "tan_impl"; #endif void tan_kernel_cuda(TensorIteratorBase& iter) { auto common_dtype = iter.common_dtype(); if (at::isComplexType(common_dtype)) { -#if AT_USE_JITERATOR() + // Disabled due to accuracy issues +#if 0 && AT_USE_JITERATOR() static const auto tan_string = jiterator_stringify( template T tan_impl(T a) { return std::tan(a); }); AT_DISPATCH_COMPLEX_TYPES_AND( diff --git a/aten/src/ATen/native/cuda/UnaryGeometricTanhKernel.cu b/aten/src/ATen/native/cuda/UnaryGeometricTanhKernel.cu index 61393eec8ca..dedb15473fc 100644 --- a/aten/src/ATen/native/cuda/UnaryGeometricTanhKernel.cu +++ b/aten/src/ATen/native/cuda/UnaryGeometricTanhKernel.cu @@ -12,14 +12,15 @@ namespace at::native { -#if AT_USE_JITERATOR() +#if 0 && AT_USE_JITERATOR() constexpr char tanh_name[] = "tanh_impl"; #endif void tanh_kernel_cuda(TensorIteratorBase& iter) { auto common_dtype = iter.common_dtype(); if (at::isComplexType(common_dtype)) { -#if AT_USE_JITERATOR() + // Disabled due to accuracy issues +#if 0 && AT_USE_JITERATOR() static const auto tanh_string = jiterator_stringify( template T tanh_impl(T a) { return std::tanh(a); }); AT_DISPATCH_COMPLEX_TYPES_AND( diff --git a/test/test_unary_ufuncs.py b/test/test_unary_ufuncs.py index 15b967e5707..13f205c4d11 100644 --- a/test/test_unary_ufuncs.py +++ b/test/test_unary_ufuncs.py @@ -773,6 +773,48 @@ class TestUnaryUfuncs(TestCase): with self.assertRaises(AttributeError): torch_inplace_method = getattr(torch.Tensor, fn_name + "_") + @onlyCUDA + @dtypes(torch.complex64) + def test_tan_complex_cuda_matches_numpy(self, device, dtype): + # Focused accuracy check for complex tan on CUDA against NumPy reference + # Includes values near tan singularities on the real axis + eps = 1e-3 + specials = torch.tensor( + [ + math.pi / 2 - eps, + math.pi / 2 + eps, + -math.pi / 2 - eps, + -math.pi / 2 + eps, + ], + device=device, + dtype=torch.float32, + ) + real = torch.randn(1024, device=device, dtype=torch.float32) * (2 * math.pi) + imag = torch.randn(1024, device=device, dtype=torch.float32) * 5.0 + real = torch.cat([real, specials]) + imag = torch.cat( + [ + imag, + torch.linspace( + -3, + 3, + steps=specials.numel(), + device=device, + ), + ] + ) + z = torch.complex(real, imag).to(dtype) + self.compare_with_numpy(torch.tan, np.tan, z) + + @onlyCUDA + @dtypes(torch.complex64) + def test_tanh_complex_cuda_matches_numpy(self, device, dtype): + # Focused accuracy check for complex tanh on CUDA against NumPy reference + real = torch.randn(2048, device=device, dtype=torch.float32) * (2 * math.pi) + imag = torch.randn(2048, device=device, dtype=torch.float32) * 5.0 + z = torch.complex(real, imag).to(dtype) + self.compare_with_numpy(torch.tanh, np.tanh, z) + def check_internal_mem_overlap( self, inplace_op, num_inputs, dtype, device, expected_failure=False ):