Fixes mem_get_info when querying on a device other than the current device (#69640)

Summary:
Also fixes the documentation failing to appear and adds a test to validate that op works with multiple devices properly.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/69640

Reviewed By: ngimel

Differential Revision: D32965391

Pulled By: mruberry

fbshipit-source-id: 4fe502809b353464da8edf62d92ca9863804f08e
This commit is contained in:
Mike Ruberry 2021-12-08 23:02:56 -08:00 committed by Facebook GitHub Bot
parent 24d885f5f8
commit dc87cf5fe1
3 changed files with 18 additions and 1 deletions

View File

@ -92,6 +92,7 @@ Memory management
empty_cache
list_gpu_processes
mem_get_info
memory_stats
memory_summary
memory_snapshot

View File

@ -1635,6 +1635,20 @@ except RuntimeError as e:
a = torch.ones(65536).cuda().half()
self.assertEqual(a.norm(p=0, dtype=torch.float32), 65536)
# Verifies that mem_get_info works, including when called for a different device
def test_mem_get_info(self):
def _test(idx):
before_free_bytes, before_available_bytes = torch.cuda.mem_get_info(idx)
t = torch.randn(1024 * 1024, device='cuda:' + str(idx))
after_free_bytes, after_available_bytes = torch.cuda.mem_get_info(idx)
self.assertTrue(after_free_bytes < before_free_bytes)
self.assertEqual(before_available_bytes, after_available_bytes)
_test(0)
if TEST_MULTIGPU:
_test(1)
# Test that wrap_with_cuda_memory_check successfully detects leak
# skip for ROCM. Look into #62533.
@skipIfRocm

View File

@ -6,6 +6,8 @@
#else
#include <hip/hip_runtime_api.h>
#endif
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAException.h>
namespace torch { namespace cuda { namespace shared {
@ -46,7 +48,7 @@ void initCudartBindings(PyObject* module) {
cudart.def("cuda" "ProfilerInitialize", cudaProfilerInitialize);
#endif
cudart.def("cuda" "MemGetInfo", [](int device) -> std::pair<size_t, size_t> {
C10_CUDA_CHECK(cudaGetDevice(&device));
c10::cuda::CUDAGuard guard(device);
size_t device_free;
size_t device_total;
cudaMemGetInfo(&device_free, &device_total);