diff --git a/aten/src/ATen/native/cuda/Sort.cpp b/aten/src/ATen/native/cuda/Sort.cpp
index 4605be8cdf1..39581cef25c 100644
--- a/aten/src/ATen/native/cuda/Sort.cpp
+++ b/aten/src/ATen/native/cuda/Sort.cpp
@@ -63,9 +63,6 @@ void sort_cuda_kernel(
     "The dimension being sorted can not have more than INT_MAX elements.");
 
   const auto self_dtype = self.dtype();
-  // FIXME: remove this check once cub sort supports bool
-  TORCH_CHECK(self_dtype != ScalarType::Bool,
-    "Sort currently does not support bool dtype on CUDA.");
   TORCH_CHECK(self_dtype != ScalarType::ComplexFloat && self_dtype != ScalarType::ComplexDouble,
     "Sort currently does not support complex dtypes on CUDA.");
diff --git a/test/test_sort_and_select.py b/test/test_sort_and_select.py
index aebfdaec0cb..6d37607ffbf 100644
--- a/test/test_sort_and_select.py
+++ b/test/test_sort_and_select.py
@@ -193,8 +193,7 @@ class TestSortAndSelect(TestCase):
             self.assertEqual(res1val, res1val_cpu.cuda())
             self.assertEqual(res1ind, res1ind_cpu.cuda())
 
-    # FIXME: remove torch.bool from unsupported types once support is added for cub sort
-    @dtypes(*all_types_and(torch.half, torch.bfloat16))
+    @dtypes(*all_types_and(torch.bool, torch.half, torch.bfloat16))
     def test_stable_sort(self, device, dtype):
         sizes = (100, 1000, 10000)
         for ncopies in sizes:
@@ -323,8 +322,7 @@ class TestSortAndSelect(TestCase):
             self.assertEqual(indices, indices_cont)
             self.assertEqual(values, values_cont)
 
-    # FIXME: remove torch.bool from unsupported types once support is added for cub sort
-    @dtypes(*all_types_and(torch.half, torch.bfloat16))
+    @dtypes(*all_types_and(torch.bool, torch.half, torch.bfloat16))
     def test_stable_sort_against_numpy(self, device, dtype):
         if dtype in floating_types_and(torch.float16, torch.bfloat16):
             inf = float("inf")
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index b9ec84b8cd9..d52a6474977 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -3320,7 +3320,10 @@ def sample_inputs_sort(op_info, device, dtype, requires_grad, **kwargs):
     flag = [True, False]
     for dim, descending, stable in product(dims, flag, flag):
         # default schema without stable sort
-        yield SampleInput(small_3d_unique(), dim, descending)
+        if not (dtype == torch.bool and torch.device(device).type == 'cuda'):
+            # bool and cuda requires stable sort for stable results, at least
+            # for the return index
+            yield SampleInput(small_3d_unique(), dim, descending)
         # schema with stable sort, no CUDA support yet
         if torch.device(device).type == 'cpu':
             yield SampleInput(
@@ -18477,11 +18480,13 @@ op_db: List[OpInfo] = [
            )),
     OpInfo('sort',
            dtypes=all_types_and(torch.bool, torch.float16, torch.bfloat16),
-           dtypesIfCUDA=all_types_and(torch.float16, torch.bfloat16),
+           dtypesIfCUDA=all_types_and(torch.bool, torch.float16, torch.bfloat16),
            sample_inputs_func=sample_inputs_sort,
            supports_forward_ad=True,
            supports_fwgrad_bwgrad=True,
            skips=(
+               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_non_standard_bool_values',
+                            dtypes=[torch.bool], device_type='cuda'),
            )),
     OpInfo('unique',
            dtypes=all_types_and(torch.bool, torch.float16, torch.bfloat16, torch.uint16, torch.uint32, torch.uint64),
@@ -19506,12 +19511,14 @@ op_db: List[OpInfo] = [
            sample_inputs_func=sample_inputs_unfold),
     OpInfo('msort',
            dtypes=all_types_and(torch.bool, torch.float16, torch.bfloat16),
-           dtypesIfCUDA=all_types_and(torch.float16, torch.bfloat16),
+           dtypesIfCUDA=all_types_and(torch.bool, torch.float16, torch.bfloat16),
            check_batched_gradgrad=False,
            supports_forward_ad=True,
            supports_fwgrad_bwgrad=True,
            sample_inputs_func=sample_inputs_msort,
            skips=(
+               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_non_standard_bool_values',
+                            dtypes=[torch.bool], device_type='cuda'),
            )),
     OpInfo('movedim',
            aliases=('moveaxis',),
@@ -21324,7 +21331,7 @@ op_db: List[OpInfo] = [
     OpInfo(
         "argsort",
         dtypes=all_types_and(torch.bool, torch.float16, torch.bfloat16),
-        dtypesIfCUDA=all_types_and(torch.float16, torch.bfloat16),
+        dtypesIfCUDA=all_types_and(torch.bool, torch.float16, torch.bfloat16),
         sample_inputs_func=sample_inputs_sort,
         supports_out=False,
         supports_autograd=False,
@@ -21335,6 +21342,13 @@ op_db: List[OpInfo] = [
                 "test_variant_consistency_jit",
                 dtypes=(torch.float32,),
             ),
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestCommon",
+                "test_non_standard_bool_values",
+                dtypes=[torch.bool],
+                device_type='cuda',
+            ),
         ),
     ),
     OpInfo(
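
Usage sketch (not part of the patch): with the dtype check removed, torch.sort on bool CUDA tensors dispatches to the cub-backed kernel. The snippet below assumes a CUDA build of PyTorch that includes this change; the expectedFailure entries added for test_non_standard_bool_values suggest that bool tensors whose backing bytes are not exactly 0 or 1 are still expected to sort incorrectly on CUDA.

import torch

if torch.cuda.is_available():
    x = torch.tensor([True, False, True, False, True], device='cuda')

    # Default (unstable) schema: sorted values are correct, but the relative
    # order of indices for equal elements is not guaranteed, which is why
    # sample_inputs_sort above skips this schema for bool on CUDA.
    values, indices = torch.sort(x)

    # Stable sort keeps equal elements in their original order, so the
    # returned indices are deterministic and comparable against the CPU/NumPy
    # reference used in test_stable_sort_against_numpy.
    values, indices = torch.sort(x, stable=True)
    print(values)   # tensor([False, False,  True,  True,  True], device='cuda:0')
    print(indices)  # tensor([1, 3, 0, 2, 4], device='cuda:0')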