pytorch/docs/source/cuda.rst
Emilio Castillo c9d4390d13 Add Pluggable CUDA allocator backend (#86786)
Fixes #43144

This uses the Backend system added by [#82682](https://github.com/pytorch/pytorch/pull/82682) to change allocators dynamically during code execution. This will allow us to use RMM, to use CUDA managed memory for portions of the code that do not fit in GPU memory, to write static memory allocators that reduce fragmentation while training models, and to improve interoperability with external DL compilers/libraries.

For example, we could have the following allocator in C++:

```c++
#include <sys/types.h>
#include <cuda_runtime_api.h>
#include <iostream>

// Exported with C linkage so the functions can be located by name
// when the shared library is loaded.
extern "C" {
void* my_malloc(ssize_t size, int device, cudaStream_t stream) {
   void* ptr;
   std::cout << "alloc " << size << std::endl;
   cudaMalloc(&ptr, size);
   return ptr;
}

void my_free(void* ptr) {
   std::cout << "free" << std::endl;
   cudaFree(ptr);
}
}
```

Compile it as a shared library:
```
nvcc allocator.cc -o alloc.so -shared --compiler-options '-fPIC'
```

And use it from PyTorch as follows:

```python
import torch

# The allocator has to be swapped in before the CUDA caching allocator is
# initialized. Uncommenting the line below (or doing any earlier CUDA
# allocation) would make change_current_allocator fail.
# b = torch.zeros(10, device='cuda')
new_alloc = torch.cuda.memory.CUDAPluggableAllocator('alloc.so', 'my_malloc', 'my_free')
old = torch.cuda.memory.get_current_allocator()
torch.cuda.memory.change_current_allocator(new_alloc)
b = torch.zeros(10, device='cuda')  # allocated through my_malloc
# This will error: the active allocator has already been instantiated and used,
# so it cannot be swapped out again
torch.cuda.memory.change_current_allocator(old)
```
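
As a side note, the Backend system from [#82682](https://github.com/pytorch/pytorch/pull/82682) that this builds on already lets you inspect which built-in backend is active. A small sketch, not part of this PR; it only uses the existing `get_allocator_backend` API and the `PYTORCH_CUDA_ALLOC_CONF` spelling from the CUDA memory-management notes:

```python
import torch

# The built-in backend is selected before the first CUDA allocation, e.g.
#   PYTORCH_CUDA_ALLOC_CONF=backend:cudaMallocAsync python train.py
print(torch.cuda.memory.get_allocator_backend())  # 'native' or 'cudaMallocAsync'
```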

Things to discuss
- How to test this: it needs compiling external code (see the sketch below for one possible approach) ...
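
One possible approach, sketched below purely as an illustration (not the test shipped with this PR), assuming `nvcc` is on the `PATH` and the example `allocator.cc` above is available: build the shared library from within the test itself and then swap the allocator in.

```python
import os
import subprocess
import tempfile

import torch

# Hypothetical test sketch: compile the example allocator at test time.
ALLOCATOR_SOURCE = "allocator.cc"  # the C++ file shown above

def test_pluggable_allocator():
    with tempfile.TemporaryDirectory() as tmp:
        lib = os.path.join(tmp, "alloc.so")
        subprocess.check_call(
            ["nvcc", ALLOCATOR_SOURCE, "-o", lib, "-shared",
             "--compiler-options", "-fPIC"]
        )
        new_alloc = torch.cuda.memory.CUDAPluggableAllocator(lib, "my_malloc", "my_free")
        torch.cuda.memory.change_current_allocator(new_alloc)
        # Any CUDA allocation now goes through my_malloc / my_free.
        x = torch.ones(10, device="cuda")
        assert x.sum().item() == 10
```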

Pull Request resolved: https://github.com/pytorch/pytorch/pull/86786
Approved by: https://github.com/albanD
2022-11-23 17:54:36 +00:00


torch.cuda
===================================
.. automodule:: torch.cuda
.. currentmodule:: torch.cuda

.. autosummary::
    :toctree: generated
    :nosignatures:

    StreamContext
    can_device_access_peer
    current_blas_handle
    current_device
    current_stream
    default_stream
    device
    device_count
    device_of
    get_arch_list
    get_device_capability
    get_device_name
    get_device_properties
    get_gencode_flags
    get_sync_debug_mode
    init
    ipc_collect
    is_available
    is_initialized
    memory_usage
    set_device
    set_stream
    set_sync_debug_mode
    stream
    synchronize
    utilization
    OutOfMemoryError

Random Number Generator
-------------------------
.. autosummary::
    :toctree: generated
    :nosignatures:

    get_rng_state
    get_rng_state_all
    set_rng_state
    set_rng_state_all
    manual_seed
    manual_seed_all
    seed
    seed_all
    initial_seed

Communication collectives
-------------------------
.. autosummary::
    :toctree: generated
    :nosignatures:

    comm.broadcast
    comm.broadcast_coalesced
    comm.reduce_add
    comm.scatter
    comm.gather

Streams and events
------------------
.. autosummary::
    :toctree: generated
    :nosignatures:

    Stream
    ExternalStream
    Event

Graphs (beta)
-------------
.. autosummary::
    :toctree: generated
    :nosignatures:

    is_current_stream_capturing
    graph_pool_handle
    CUDAGraph
    graph
    make_graphed_callables

.. _cuda-memory-management-api:

Memory management
-----------------
.. autosummary::
    :toctree: generated
    :nosignatures:

    empty_cache
    list_gpu_processes
    mem_get_info
    memory_stats
    memory_summary
    memory_snapshot
    memory_allocated
    max_memory_allocated
    reset_max_memory_allocated
    memory_reserved
    max_memory_reserved
    set_per_process_memory_fraction
    memory_cached
    max_memory_cached
    reset_max_memory_cached
    reset_peak_memory_stats
    caching_allocator_alloc
    caching_allocator_delete
    get_allocator_backend
    CUDAPluggableAllocator
    change_current_allocator
.. FIXME The following doesn't seem to exist. Is it supposed to?
   https://github.com/pytorch/pytorch/issues/27785
   .. autofunction:: reset_max_memory_reserved

NVIDIA Tools Extension (NVTX)
-----------------------------
.. autosummary::
    :toctree: generated
    :nosignatures:

    nvtx.mark
    nvtx.range_push
    nvtx.range_pop

Jiterator (beta)
-----------------------------
.. autosummary::
    :toctree: generated
    :nosignatures:

    jiterator._create_jit_fn
    jiterator._create_multi_output_jit_fn

Stream Sanitizer (prototype)
----------------------------

CUDA Sanitizer is a prototype tool for detecting synchronization errors between streams in PyTorch.
See the :doc:`documentation <cuda._sanitizer>` for information on how to use it.
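
As a minimal illustration (assuming the ``TORCH_CUDA_SANITIZER=1`` environment
variable described on that page), a script that touches a tensor on a side
stream without synchronization would be flagged when run as
``TORCH_CUDA_SANITIZER=1 python example.py``::

    import torch

    a = torch.rand(4, 2, device="cuda")  # written on the default stream

    with torch.cuda.stream(torch.cuda.Stream()):
        a.add_(1)  # used on a side stream without synchronization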

.. toctree::
    :hidden:

    cuda._sanitizer