mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Fix C10_CUDA_CHECK occasionally failing to capture the last CUDA error. This error was accidentally introduced by #92227, which was trying to fix #91758 as introduced in #85256. The unit test `TestCuda.test_events_multi_gpu_elapsed_time` has been failing since that PR was merged (on CUDA 11.8 and CUDA 12.0). That test requires at least 2 GPUs, so it is probably not run in the OSS CI: ``` python test/test_cuda.py -v -k TestCuda.test_events_multi_gpu_elapsed_time ``` E.g. in https://github.com/pytorch/pytorch/actions/runs/4026926691/jobs/6922406192 ``` 2023-01-27T19:41:32.2312162Z test_events_multi_gpu_elapsed_time (__main__.TestCuda) ... skip: detected only one GPU (0.001s) ``` The original C10_CUDA_CHECK before #85256 had an extra `cudaGetLastError` call that captures those CUDA errors, https://github.com/pytorch/pytorch/pull/85256/files#diff-0823e63e781acf56e93a5553ed7feee0db0bda05d86e2560c7b80e87e32e0024L41-L42 This extra `cudaGetLastError` was originally introduced in #17337. As commented here https://github.com/pytorch/pytorch/pull/17337/files#r259104503 > soumith on Feb 21, 2019: Without this, a previously raised error was still lingering and falsely being triggered for a subsequent CUDA call. colesbury suggested that this is the right thing to do. Pull Request resolved: https://github.com/pytorch/pytorch/pull/93192 Approved by: https://github.com/ezyang
49 lines
1.3 KiB
C++
49 lines
1.3 KiB
C++
#include <c10/cuda/CUDAException.h>
|
|
|
|
#include <c10/cuda/CUDADeviceAssertionHost.h>
|
|
#include <c10/util/Exception.h>
|
|
#include <cuda_runtime.h>
|
|
|
|
#include <string>
|
|
|
|
namespace c10 {
|
|
namespace cuda {
|
|
|
|
void c10_cuda_check_implementation(
|
|
const int32_t err,
|
|
const char* filename,
|
|
const char* function_name,
|
|
const int line_number,
|
|
const bool include_device_assertions) {
|
|
const auto cuda_error = static_cast<cudaError_t>(err);
|
|
const auto cuda_kernel_failure = include_device_assertions
|
|
? c10::cuda::CUDAKernelLaunchRegistry::get_singleton_ref().has_failed()
|
|
: false;
|
|
|
|
if (C10_LIKELY(cuda_error == cudaSuccess && !cuda_kernel_failure)) {
|
|
return;
|
|
}
|
|
|
|
auto error_unused C10_UNUSED = cudaGetLastError();
|
|
(void)error_unused;
|
|
|
|
std::string check_message;
|
|
#ifndef STRIP_ERROR_MESSAGES
|
|
check_message.append("CUDA error: ");
|
|
check_message.append(cudaGetErrorString(cuda_error));
|
|
check_message.append(c10::cuda::get_cuda_check_suffix());
|
|
check_message.append("\n");
|
|
if (include_device_assertions) {
|
|
check_message.append(c10_retrieve_device_side_assertion_info());
|
|
} else {
|
|
check_message.append(
|
|
"Device-side assertions were explicitly omitted for this error check; the error probably arose while initializing the DSA handlers.");
|
|
}
|
|
#endif
|
|
|
|
TORCH_CHECK(false, check_message);
|
|
}
|
|
|
|
} // namespace cuda
|
|
} // namespace c10
|