from torch.testing._internal.common_utils import TestCase, run_tests from torch.testing import check_cuda_kernel_launches, check_code_for_cuda_kernel_launches class AlwaysCheckCudaLaunchTest(TestCase): def test_check_code(self): """Verifies that the regex works for a few different situations""" # Try some different spacings self.assertEqual(2, check_code_for_cuda_kernel_launches(""" some_function_call<<<1,2,0,stream>>>(arg1,arg2,arg3); C10_CUDA_KERNEL_LAUNCH_CHECK(); some_function_call<<<1,2,0,stream>>>(arg1,arg2,arg3); some_function_call<<<1,2,0,stream>>>(arg1,arg2,arg3); C10_CUDA_KERNEL_LAUNCH_CHECK(); some_function_call<<<1,2,0,stream>>>(arg1,arg2,arg3); some_other_stuff; some_function_call<<<1,2,0,stream>>>(arg1,arg2,arg3); C10_CUDA_KERNEL_LAUNCH_CHECK(); some_function_call<<<1,2,0,stream>>> (arg1,arg2,arg3); C10_CUDA_KERNEL_LAUNCH_CHECK(); some_function_call<<<1,2,0,stream>>> ( arg1 , arg2 , arg3 ) ; C10_CUDA_KERNEL_LAUNCH_CHECK(); """)) # Does it work for macros? self.assertEqual(0, check_code_for_cuda_kernel_launches(r""" #define SOME_MACRO(x) some_function_call<<<1,2>>> ( x ) ; \ C10_CUDA_KERNEL_LAUNCH_CHECK(); #define SMALL_INDEX(TENSOR_TYPE, INDICES_TYPE, TYPE, SELF_DIM, SOURCE_DIM, IDX_DIM) \ indexAddSmallIndex \ <<>>( \ selfInfo, sourceInfo, indexInfo, \ selfAddDim, sourceAddDim, sliceSize, selfAddDimSize); \ C10_CUDA_KERNEL_LAUNCH_CHECK(); """)) def test_check_cuda_launches(self): unsafeLaunchesCount = check_cuda_kernel_launches() self.assertTrue(unsafeLaunchesCount == 0) if __name__ == '__main__': run_tests()