From 2793cf85ece134f37b6726dbc643814d2b8bce17 Mon Sep 17 00:00:00 2001 From: Richard Barnes Date: Wed, 30 Mar 2022 10:07:52 -0700 Subject: [PATCH] Check all CUDA API calls for errors in caffe2/c10/ (#74918) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/74918 Test Plan: Sandcastle Reviewed By: ngimel Differential Revision: D35194795 fbshipit-source-id: 8490e5497c37bab0055925ed520c2fd0c37a554c (cherry picked from commit 52697ab670e2f53c580cfd4ca82c5468ed3bb06c) --- c10/cuda/CUDACachingAllocator.cpp | 14 +++++++------- c10/cuda/CUDAFunctions.cpp | 7 ++----- c10/cuda/CUDAStream.h | 2 +- c10/cuda/impl/CUDAGuardImpl.h | 4 ++-- 4 files changed, 12 insertions(+), 15 deletions(-) diff --git a/c10/cuda/CUDACachingAllocator.cpp b/c10/cuda/CUDACachingAllocator.cpp index b6368c41c28..49e7f3c3d13 100644 --- a/c10/cuda/CUDACachingAllocator.cpp +++ b/c10/cuda/CUDACachingAllocator.cpp @@ -315,7 +315,7 @@ cudaError_t cudaMallocMaybeCapturing(void** p, size_t size) { if (at::cuda::currentStreamCaptureStatusMayInitCtx() == at::cuda::CaptureStatus::None) { #endif - return cudaMalloc(p, size); + return C10_CUDA_ERROR_HANDLED(cudaMalloc(p, size)); #if defined(CUDA_VERSION) && CUDA_VERSION >= 11000 } else { // It's ok to capture cudaMallocs, as long as we never cudaFree those @@ -323,7 +323,7 @@ cudaError_t cudaMallocMaybeCapturing(void** p, size_t size) { // Capturing cudaMalloc behaves nicely: it gives the graph new VA, // but is ignored (won't leakily allocate new memory) in replays. at::cuda::CUDAStreamCaptureModeGuard g{cudaStreamCaptureModeRelaxed}; - return cudaMalloc(p, size); + return C10_CUDA_ERROR_HANDLED(cudaMalloc(p, size)); } #endif } @@ -756,9 +756,9 @@ class DeviceCachingAllocator { if (*largest == 0) { // make an initial guess if a zero *largest is passed in size_t tmp_bytes; - cudaMemGetInfo( + C10_CUDA_CHECK(cudaMemGetInfo( largest, // Use free memory as an optimistic initial guess of *largest - &tmp_bytes); + &tmp_bytes)); } cache_info_aux(large_blocks, total, largest); cache_info_aux(small_blocks, total, largest); @@ -1458,7 +1458,7 @@ class DeviceCachingAllocator { cudaEvent_t event = e.first; Block* block = e.second; - cudaError_t err = cudaEventQuery(event); + cudaError_t err = C10_CUDA_ERROR_HANDLED(cudaEventQuery(event)); if (err == cudaErrorNotReady) { // ignore and clear the error if not ready cudaGetLastError(); @@ -1576,9 +1576,9 @@ class THCCachingAllocator { fraction, ". Please set within (0, 1)."); int activated_device; - cudaGetDevice(&activated_device); + C10_CUDA_CHECK(cudaGetDevice(&activated_device)); if (activated_device != device) { - cudaSetDevice(device); + C10_CUDA_CHECK(cudaSetDevice(device)); } device_allocator[device]->setMemoryFraction(fraction); } diff --git a/c10/cuda/CUDAFunctions.cpp b/c10/cuda/CUDAFunctions.cpp index 255d798d13f..9ab61aa1f38 100644 --- a/c10/cuda/CUDAFunctions.cpp +++ b/c10/cuda/CUDAFunctions.cpp @@ -10,16 +10,13 @@ namespace { // returns -1 on failure int32_t driver_version() { int driver_version = -1; - cudaError_t err = cudaDriverGetVersion(&driver_version); - if (err != cudaSuccess) { - cudaError_t last_err C10_UNUSED = cudaGetLastError(); - } + C10_CUDA_IGNORE_ERROR(cudaDriverGetVersion(&driver_version)); return driver_version; } int device_count_impl(bool fail_if_no_driver) { int count; - auto err = cudaGetDeviceCount(&count); + auto err = C10_CUDA_ERROR_HANDLED(cudaGetDeviceCount(&count)); if (err == cudaSuccess) { return count; } diff --git a/c10/cuda/CUDAStream.h b/c10/cuda/CUDAStream.h index 7bb97e88b99..6d17136341c 100644 --- a/c10/cuda/CUDAStream.h +++ b/c10/cuda/CUDAStream.h @@ -111,7 +111,7 @@ class C10_CUDA_API CUDAStream { bool query() const { DeviceGuard guard{stream_.device()}; - cudaError_t err = cudaStreamQuery(stream()); + cudaError_t err = C10_CUDA_ERROR_HANDLED(cudaStreamQuery(stream())); if (err == cudaSuccess) { return true; diff --git a/c10/cuda/impl/CUDAGuardImpl.h b/c10/cuda/impl/CUDAGuardImpl.h index 8f5cfdc259d..583feeec260 100644 --- a/c10/cuda/impl/CUDAGuardImpl.h +++ b/c10/cuda/impl/CUDAGuardImpl.h @@ -41,7 +41,7 @@ struct CUDAGuardImpl final : public c10::impl::DeviceGuardImplInterface { } c10::optional uncheckedGetDevice() const noexcept { int device; - auto err = cudaGetDevice(&device); + const auto err = C10_CUDA_ERROR_HANDLED(cudaGetDevice(&device)); C10_CUDA_CHECK_WARN(err); if (err != cudaSuccess) { return c10::nullopt; @@ -164,7 +164,7 @@ struct CUDAGuardImpl final : public c10::impl::DeviceGuardImplInterface { if (!event) return true; cudaEvent_t cuda_event = static_cast(event); - const cudaError_t err = cudaEventQuery(cuda_event); + const cudaError_t err = C10_CUDA_ERROR_HANDLED(cudaEventQuery(cuda_event)); if (err != cudaErrorNotReady) { C10_CUDA_CHECK(err); } else {