Fixes mem_get_info when querying on a device other than the current device (#69640)
Summary: Also fixes the documentation failing to appear and adds a test to validate that the op works properly with multiple devices.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/69640
Reviewed By: ngimel
Differential Revision: D32965391
Pulled By: mruberry
fbshipit-source-id: 4fe502809b353464da8edf62d92ca9863804f08e
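For context, a minimal usage sketch of the fixed behavior (illustrative variable names; assumes at least two visible CUDA devices):

    import torch

    # torch.cuda.mem_get_info(device) returns (free_bytes, total_bytes) for the
    # queried device. With this fix, the query honors the device argument even
    # when it differs from the current device.
    torch.cuda.set_device(0)
    free0, total0 = torch.cuda.mem_get_info(0)

    if torch.cuda.device_count() > 1:
        # Queried while cuda:0 is current, but the numbers describe cuda:1.
        free1, total1 = torch.cuda.mem_get_info(1)
        print("cuda:0 free/total bytes:", free0, total0)
        print("cuda:1 free/total bytes:", free1, total1)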
This commit is contained in:
parent 24d885f5f8
commit dc87cf5fe1
Documentation: add mem_get_info to the memory management listing.

@@ -92,6 +92,7 @@ Memory management
 
     empty_cache
     list_gpu_processes
+    mem_get_info
    memory_stats
    memory_summary
    memory_snapshot
Test suite: verify that mem_get_info works, including when called for a non-current device.

@@ -1635,6 +1635,20 @@ except RuntimeError as e:
         a = torch.ones(65536).cuda().half()
         self.assertEqual(a.norm(p=0, dtype=torch.float32), 65536)
 
+    # Verifies that mem_get_info works, including when called for a different device
+    def test_mem_get_info(self):
+        def _test(idx):
+            before_free_bytes, before_available_bytes = torch.cuda.mem_get_info(idx)
+            t = torch.randn(1024 * 1024, device='cuda:' + str(idx))
+            after_free_bytes, after_available_bytes = torch.cuda.mem_get_info(idx)
+
+            self.assertTrue(after_free_bytes < before_free_bytes)
+            self.assertEqual(before_available_bytes, after_available_bytes)
+
+        _test(0)
+        if TEST_MULTIGPU:
+            _test(1)
+
     # Test that wrap_with_cuda_memory_check successfully detects leak
     # skip for ROCM. Look into #62533.
     @skipIfRocm
C++ binding (cudart): pull in CUDAGuard.

@@ -6,6 +6,8 @@
 #else
 #include <hip/hip_runtime_api.h>
 #endif
+
+#include <c10/cuda/CUDAGuard.h>
 #include <c10/cuda/CUDAException.h>
 
 namespace torch { namespace cuda { namespace shared {
@@ -46,7 +48,7 @@ void initCudartBindings(PyObject* module) {
   cudart.def("cuda" "ProfilerInitialize", cudaProfilerInitialize);
 #endif
   cudart.def("cuda" "MemGetInfo", [](int device) -> std::pair<size_t, size_t> {
-    C10_CUDA_CHECK(cudaGetDevice(&device));
+    c10::cuda::CUDAGuard guard(device);
     size_t device_free;
     size_t device_total;
     cudaMemGetInfo(&device_free, &device_total);
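The one-line change above stops overwriting the requested device index with the current device (cudaGetDevice) and instead enters a scoped c10::cuda::CUDAGuard, so cudaMemGetInfo runs with the requested device active and the caller's current device is restored on return. A Python-level sketch of the resulting behavior (illustrative allocation size and device indices; assumes a fresh process with at least two visible GPUs):

    import torch

    if torch.cuda.device_count() > 1:
        torch.cuda.set_device(0)

        before_free, total = torch.cuda.mem_get_info(1)
        # Allocating on cuda:1 should show up in cuda:1's free-memory figure.
        buf = torch.empty(32 * 1024 * 1024, dtype=torch.uint8, device='cuda:1')
        after_free, _ = torch.cuda.mem_get_info(1)

        assert after_free < before_free          # allocation visible on cuda:1
        assert torch.cuda.current_device() == 0  # guard restored the current device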