From 09cbf34e9386821a2a72990a6b4870f27bc129fc Mon Sep 17 00:00:00 2001 From: albanD Date: Mon, 15 Sep 2025 13:29:43 +0000 Subject: [PATCH] [BE] Preserve caller source location in the error message (#162808) Summary: Currently the C10_CUDA_CHECK only shows source location in CUDAException like below: ``` Exception raised from c10_cuda_check_implementation at fbcode/caffe2/c10/cuda/CUDAException.cpp:44 ``` which is not terribly useful. By checking the original diff D39619861 that introduced c10_cuda_check_implementation, it seems the original macro would show the source location correctly but c10_cuda_check_implementation broke it. This diff will propagate caller source location to c10_cuda_check_implementation to fix the issue. Test Plan: CI Observed desired error message after the change: ``` CUDA error: an illegal memory access was encountered Search for `cudaErrorIllegalAddress' in https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html for more information. CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect. For debugging consider passing CUDA_LAUNCH_BLOCKING=1 Device-side assertion tracking was not enabled by user. Exception raised from operator() at fbcode/sigrid/predictor/aed/AedContainer.cpp:659 (most recent call first): ``` Note the last line reports actual caller location. Rollback Plan: Reviewed By: Raymo111 Differential Revision: D81880552 Pull Request resolved: https://github.com/pytorch/pytorch/pull/162808 Approved by: https://github.com/janeyx99 --- c10/cuda/CUDAException.cpp | 8 ++++---- c10/cuda/CUDAException.h | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/c10/cuda/CUDAException.cpp b/c10/cuda/CUDAException.cpp index 457d35f020b..4e4419b4369 100644 --- a/c10/cuda/CUDAException.cpp +++ b/c10/cuda/CUDAException.cpp @@ -10,9 +10,9 @@ namespace c10::cuda { void c10_cuda_check_implementation( const int32_t err, - const char* /*filename*/, - const char* /*function_name*/, - const int /*line_number*/, + const char* filename, + const char* function_name, + const uint32_t line_number, const bool include_device_assertions) { const auto cuda_error = static_cast(err); const auto cuda_kernel_failure = include_device_assertions @@ -41,7 +41,7 @@ void c10_cuda_check_implementation( } #endif throw c10::AcceleratorError( - {__func__, __FILE__, int32_t(__LINE__)}, err, check_message); + {function_name, filename, line_number}, err, check_message); } } // namespace c10::cuda diff --git a/c10/cuda/CUDAException.h b/c10/cuda/CUDAException.h index 899d85e8a73..2503b22e476 100644 --- a/c10/cuda/CUDAException.h +++ b/c10/cuda/CUDAException.h @@ -91,7 +91,7 @@ C10_CUDA_API void c10_cuda_check_implementation( const int32_t err, const char* filename, const char* function_name, - const int line_number, + const uint32_t line_number, const bool include_device_assertions); } // namespace c10::cuda