diff --git a/docs/source/cuda.rst b/docs/source/cuda.rst
index 7f557a9a873..16c5a774635 100644
--- a/docs/source/cuda.rst
+++ b/docs/source/cuda.rst
@@ -92,6 +92,7 @@ Memory management
 
     empty_cache
    list_gpu_processes
+    mem_get_info
     memory_stats
     memory_summary
     memory_snapshot
diff --git a/test/test_cuda.py b/test/test_cuda.py
index 5ee07168ad7..b99a6246ea4 100644
--- a/test/test_cuda.py
+++ b/test/test_cuda.py
@@ -1635,6 +1635,20 @@ except RuntimeError as e:
         a = torch.ones(65536).cuda().half()
         self.assertEqual(a.norm(p=0, dtype=torch.float32), 65536)
 
+    # Verifies that mem_get_info works, including when called for a different device
+    def test_mem_get_info(self):
+        def _test(idx):
+            before_free_bytes, before_available_bytes = torch.cuda.mem_get_info(idx)
+            t = torch.randn(1024 * 1024, device='cuda:' + str(idx))
+            after_free_bytes, after_available_bytes = torch.cuda.mem_get_info(idx)
+
+            self.assertTrue(after_free_bytes < before_free_bytes)
+            self.assertEqual(before_available_bytes, after_available_bytes)
+
+        _test(0)
+        if TEST_MULTIGPU:
+            _test(1)
+
     # Test that wrap_with_cuda_memory_check successfully detects leak
     # skip for ROCM. Look into #62533.
     @skipIfRocm
diff --git a/torch/csrc/cuda/shared/cudart.cpp b/torch/csrc/cuda/shared/cudart.cpp
index be6010f71ab..b93d921a16a 100644
--- a/torch/csrc/cuda/shared/cudart.cpp
+++ b/torch/csrc/cuda/shared/cudart.cpp
@@ -6,6 +6,8 @@
 #else
 #include <hip/hip_runtime_api.h>
 #endif
+
+#include <c10/cuda/CUDAGuard.h>
 #include <c10/cuda/CUDAException.h>
 
 namespace torch { namespace cuda { namespace shared {
@@ -46,7 +48,7 @@ void initCudartBindings(PyObject* module) {
   cudart.def("cuda" "ProfilerInitialize", cudaProfilerInitialize);
 #endif
   cudart.def("cuda" "MemGetInfo", [](int device) -> std::pair<size_t, size_t> {
-    C10_CUDA_CHECK(cudaGetDevice(&device));
+    c10::cuda::CUDAGuard guard(device);
     size_t device_free;
     size_t device_total;
     cudaMemGetInfo(&device_free, &device_total);
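
For context, a minimal usage sketch (not part of the diff): `torch.cuda.mem_get_info(device)` wraps `cudaMemGetInfo` and returns a `(free_bytes, total_bytes)` pair for the queried device, and with the `CUDAGuard` change above it honors the `device` argument instead of always reporting the current device. The loop and GiB formatting below are illustrative only.

```python
# Illustrative only -- not part of this PR. Assumes a CUDA build of PyTorch.
import torch

if torch.cuda.is_available():
    for idx in range(torch.cuda.device_count()):
        # (free_bytes, total_bytes) for device `idx`, regardless of the
        # currently selected device, thanks to the CUDAGuard change above.
        free_bytes, total_bytes = torch.cuda.mem_get_info(idx)
        used_gib = (total_bytes - free_bytes) / 2**30
        print(f"cuda:{idx}: {free_bytes / 2**30:.2f} GiB free / "
              f"{total_bytes / 2**30:.2f} GiB total ({used_gib:.2f} GiB in use)")
```

Note that, unlike `torch.cuda.memory_allocated`, these numbers come from the driver, so the "in use" figure includes memory held by other processes and by PyTorch's caching allocator; this is also why the test above only asserts that free memory decreases while total stays constant.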