[Reland] Add wrappers for synchronous GPUDirect Storage APIs (#133489)

Reland #130633

USE_CUFILE is turned off by default in this version.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/133489
Approved by: https://github.com/albanD
Mikayla Gawarecki 2024-08-15 17:11:52 +00:00 committed by PyTorch MergeBot
parent c23dceb8f1
commit 018e48c337
19 changed files with 391 additions and 1 deletions
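The user-facing surface is small: torch.cuda.gds._GdsFile plus raw torch._C._gds_* bindings, all synchronous. A minimal usage sketch, assuming a custom build with USE_CUFILE=1, the cuFile driver installed, and a GDS-supported filesystem (the path below is a hypothetical placeholder):

# Sketch only -- this commit ships with USE_CUFILE=0, so a build with
# USE_CUFILE=1 is assumed; "/mnt/nvme/tensor.bin" is a placeholder path
# on a GDS-capable (e.g. ext4/xfs) filesystem.
import os
import torch

src = torch.arange(1024, dtype=torch.float32, device="cuda")
dst = torch.empty_like(src)

f = torch.cuda.gds._GdsFile("/mnt/nvme/tensor.bin", os.O_CREAT | os.O_RDWR)
f.save_storage(src.untyped_storage(), offset=0)  # device -> file (cuFileWrite)
f.load_storage(dst.untyped_storage(), offset=0)  # file -> device (cuFileRead)
f.deregister_handle()

assert torch.equal(src, dst)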

View File

@@ -251,6 +251,15 @@ cmake_dependent_option(USE_CUDNN "Use cuDNN" ON "USE_CUDA" OFF)
cmake_dependent_option(USE_STATIC_CUDNN "Use cuDNN static libraries" OFF
    "USE_CUDNN" OFF)
cmake_dependent_option(USE_CUSPARSELT "Use cuSPARSELt" ON "USE_CUDA" OFF)
# Binary builds will fail for cufile due to https://github.com/pytorch/builder/issues/1924
# Using TH_BINARY_BUILD to check whether this is a binary build.
# USE_ROCM is guarded against in Dependencies.cmake because USE_ROCM is not properly defined here
if(DEFINED ENV{TH_BINARY_BUILD})
  cmake_dependent_option(USE_CUFILE "Use cuFile" OFF
      "USE_CUDA AND NOT $ENV{TH_BINARY_BUILD} AND NOT WIN32" OFF)
else()
  cmake_dependent_option(USE_CUFILE "Use cuFile" OFF "USE_CUDA AND NOT WIN32" OFF)
endif()
option(USE_FBGEMM "Use FBGEMM (quantized 8-bit server operators)" ON)
option(USE_KINETO "Use Kineto profiling library" ON)
option(USE_CUPTI_SO "Use CUPTI as a shared library" ON)
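Because USE_CUFILE defaults to OFF (and is forced off for binary builds), the bindings may be stubs at runtime. One way to probe, sketched from the stub behavior in torch/cuda/gds.py below; note a plain hasattr check on torch._C does not work, since the stubs are injected into torch._C at import time:

# Probe sketch: with USE_CUFILE=0 the injected stub raises
# "torch._C._gds_register_handle is not supported on this platform";
# a real binding fails differently (e.g. "cuFileHandleRegister failed")
# when handed an invalid fd.
import torch

def cufile_bindings_present() -> bool:
    try:
        handle = torch._C._gds_register_handle(-1)  # invalid fd on purpose
    except RuntimeError as e:
        return "is not supported on this platform" not in str(e)
    torch._C._gds_deregister_handle(handle)  # unlikely: registration succeeded
    return True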

View File

@@ -773,6 +773,7 @@ libtorch_python_cuda_core_sources = [
    "torch/csrc/cuda/shared/cudart.cpp",
    "torch/csrc/cuda/shared/nvtx.cpp",
    "torch/csrc/cuda/utils.cpp",
    "torch/csrc/cuda/GdsFile.cpp",
]
libtorch_python_cuda_sources = libtorch_python_cuda_core_sources + [

View File

@@ -928,6 +928,10 @@ elseif(USE_CUDA)
    torch_compile_options(torch_cuda) # see cmake/public/utils.cmake
    target_compile_definitions(torch_cuda PRIVATE USE_CUDA)
    if(USE_CUFILE)
      target_link_libraries(torch_cuda PRIVATE torch::cufile)
      target_compile_definitions(torch_cuda PRIVATE USE_CUFILE)
    endif()
    if(USE_CUSPARSELT)
      target_link_libraries(torch_cuda PRIVATE torch::cusparselt)
      target_compile_definitions(torch_cuda PRIVATE USE_CUSPARSELT)

View File

@@ -39,6 +39,7 @@ if(USE_CUDA)
  set(CAFFE2_USE_CUDA ${USE_CUDA})
  set(CAFFE2_USE_CUDNN ${USE_CUDNN})
  set(CAFFE2_USE_CUSPARSELT ${USE_CUSPARSELT})
  set(CAFFE2_USE_CUFILE ${USE_CUFILE})
  set(CAFFE2_USE_NVRTC ${USE_NVRTC})
  include(${CMAKE_CURRENT_LIST_DIR}/public/cuda.cmake)
  if(CAFFE2_USE_CUDA)
@@ -60,6 +61,9 @@
    else()
      caffe2_update_option(USE_CUSPARSELT OFF)
    endif()
    if(CAFFE2_USE_CUFILE)
      list(APPEND Caffe2_CUDA_DEPENDENCY_LIBS torch::cufile)
    endif()
    find_program(SCCACHE_EXECUTABLE sccache)
    if(SCCACHE_EXECUTABLE)
      # Using RSP/--options-file renders output noncacheable by sccache
@@ -79,6 +83,7 @@ if(USE_CUDA)
    set(CAFFE2_USE_CUDA OFF)
    set(CAFFE2_USE_CUDNN OFF)
    set(CAFFE2_USE_CUSPARSELT OFF)
    set(CAFFE2_USE_CUFILE OFF)
    set(CAFFE2_USE_NVRTC OFF)
  endif()
endif()
@@ -1039,7 +1044,6 @@ if(USE_ROCM)
    caffe2_update_option(USE_SYSTEM_NCCL ON)
  endif()
  list(APPEND HIP_CXX_FLAGS -fPIC)
  list(APPEND HIP_CXX_FLAGS -D__HIP_PLATFORM_AMD__=1)
  list(APPEND HIP_CXX_FLAGS -DCUDA_HAS_FP16=1)

View File

@@ -978,6 +978,14 @@ if(CUDAToolkit_FOUND)
    _CUDAToolkit_find_and_add_import_lib(cublas_static DEPS culibos)
  endif()

  if(CUDAToolkit_VERSION VERSION_GREATER_EQUAL 11.4)
    _CUDAToolkit_find_and_add_import_lib(cuFile ALT cufile DEPS culibos)
    _CUDAToolkit_find_and_add_import_lib(cuFile_static ALT cufile_static DEPS culibos)
    _CUDAToolkit_find_and_add_import_lib(cuFile_rdma ALT cufile_rdma DEPS cuFile culibos)
    _CUDAToolkit_find_and_add_import_lib(cuFile_rdma_static ALT cufile_rdma_static DEPS cuFile_static culibos)
  endif()

  # cuFFTW depends on cuFFT
  _CUDAToolkit_find_and_add_import_lib(cufftw DEPS cufft)
  _CUDAToolkit_find_and_add_import_lib(cufftw_static DEPS cufft_static)

View File

@@ -74,6 +74,7 @@ function(caffe2_print_configuration_summary)
  message(STATUS "    CUDA static link      : ${CAFFE2_STATIC_LINK_CUDA}")
  message(STATUS "    USE_CUDNN             : ${USE_CUDNN}")
  message(STATUS "    USE_CUSPARSELT        : ${USE_CUSPARSELT}")
  message(STATUS "    USE_CUFILE            : ${USE_CUFILE}")
  message(STATUS "    CUDA version          : ${CUDA_VERSION}")
  message(STATUS "    USE_FLASH_ATTENTION   : ${USE_FLASH_ATTENTION}")
  message(STATUS "    USE_MEM_EFF_ATTENTION : ${USE_MEM_EFF_ATTENTION}")
@@ -83,6 +84,9 @@ function(caffe2_print_configuration_summary)
  if(${USE_CUSPARSELT})
    message(STATUS "      cuSPARSELt version  : ${CUSPARSELT_VERSION}")
  endif()
  if(${USE_CUFILE})
    message(STATUS "      cufile library      : ${CUDA_cuFile_LIBRARY}")
  endif()
  message(STATUS "    CUDA root directory   : ${CUDA_TOOLKIT_ROOT_DIR}")
  message(STATUS "    CUDA library          : ${CUDA_cuda_driver_LIBRARY}")
  message(STATUS "    cudart library        : ${CUDA_cudart_LIBRARY}")

View File

@@ -244,6 +244,22 @@ else()
  message(STATUS "USE_CUSPARSELT is set to 0. Compiling without cuSPARSELt support")
endif()

# cufile
if(CAFFE2_USE_CUFILE)
  add_library(torch::cufile INTERFACE IMPORTED)
  if(CAFFE2_STATIC_LINK_CUDA AND NOT WIN32)
    set_property(
        TARGET torch::cufile PROPERTY INTERFACE_LINK_LIBRARIES
        CUDA::cuFile_static)
  else()
    set_property(
        TARGET torch::cufile PROPERTY INTERFACE_LINK_LIBRARIES
        CUDA::cuFile)
  endif()
else()
  message(STATUS "USE_CUFILE is set to 0. Compiling without cuFile support")
endif()

# curand
add_library(caffe2::curand INTERFACE IMPORTED)
if(CAFFE2_STATIC_LINK_CUDA AND NOT WIN32)

View File

@@ -183,6 +183,7 @@ See the :doc:`documentation <cuda._sanitizer>` for information on how to use it.
.. for tracking purposes
.. py:module:: torch.cuda.comm
.. py:module:: torch.cuda.error
.. py:module:: torch.cuda.gds
.. py:module:: torch.cuda.graphs
.. py:module:: torch.cuda.jiterator
.. py:module:: torch.cuda.memory

View File

@@ -38,6 +38,9 @@
#   USE_CUSPARSELT=0
#     disables the cuSPARSELt build
#
#   USE_CUFILE=0
#     disables the cuFile build
#
#   USE_FBGEMM=0
#     disables the FBGEMM build
#

View File

@@ -17,6 +17,8 @@ from copy import deepcopy
from itertools import product
from random import randint

import psutil
import torch
import torch.cuda
from torch import inf, nan
@@ -62,6 +64,7 @@ from torch.testing._internal.common_utils import (
    skipIfRocm,
    slowTest,
    subtest,
    TemporaryFileName,
    TEST_CUDA,
    TEST_CUDA_GRAPH,
    TEST_NUMPY,
@@ -3998,6 +4001,16 @@ print(f"{{r1}}, {{r2}}")
        x = torch.cuda.device_count()
        self.assertEqual(f"{x}, 1", r)

    @unittest.skip("Disabling as USE_CUFILE=0 by default in builds")
    def test_gds_fails_in_ci(self):
        if IS_WINDOWS or TEST_WITH_ROCM:
            error_msg = "is not supported on this platform"
        else:
            error_msg = "cuFileHandleRegister failed"
        with TemporaryFileName() as f:
            with self.assertRaisesRegex(RuntimeError, error_msg):
                file = torch.cuda.gds._GdsFile(f, os.O_CREAT | os.O_RDWR)

@torch.testing._internal.common_utils.markDynamoStrictTest
class TestCudaMallocAsync(TestCase):
@@ -5242,6 +5255,40 @@ class TestCudaOptims(TestCase):
        self.assertEqual(scaler._growth_tracker, growth_tracker)

class TestGDS(TestCase):
    def _get_tmp_dir_fs_type(self):
        my_path = os.path.realpath("/tmp")
        root_type = ""
        for part in psutil.disk_partitions():
            if part.mountpoint == "/":
                root_type = part.fstype
                continue
            if part.mountpoint == my_path:
                return part.fstype
        return root_type

    @unittest.skip("Disabling as USE_CUFILE=0 by default in builds")
    def test_gds_read_write_tensors(self):
        if self._get_tmp_dir_fs_type() not in ("ext4", "xfs"):
            self.skipTest("GPUDirect Storage requires ext4/xfs for local filesystem")
        src1 = torch.randn(1024, device="cuda")
        src2 = torch.randn(2, 1024, device="cuda")
        torch.cuda.gds._gds_register_buffer(src1.untyped_storage())
        torch.cuda.gds._gds_register_buffer(src2.untyped_storage())
        dest1 = torch.empty(1024, device="cuda")
        dest2 = torch.empty(2, 1024, device="cuda")
        with TemporaryFileName() as f:
            file = torch.cuda.gds._GdsFile(f, os.O_CREAT | os.O_RDWR)
            file.save_storage(src1.untyped_storage(), offset=0)
            file.save_storage(src2.untyped_storage(), offset=src1.nbytes)
            file.load_storage(dest1.untyped_storage(), offset=0)
            file.load_storage(dest2.untyped_storage(), offset=src1.nbytes)
        self.assertEqual(src1, dest1)
        self.assertEqual(src2, dest2)
        torch.cuda.gds._gds_deregister_buffer(src1.untyped_storage())
        torch.cuda.gds._gds_deregister_buffer(src2.untyped_storage())

instantiate_parametrized_tests(TestCuda)
instantiate_parametrized_tests(TestCudaMallocAsync)
instantiate_device_type_tests(TestCudaOptims, globals())

View File

@@ -60,6 +60,12 @@ cc_library(
    visibility = ["//visibility:public"],
)

cc_library(
    name = "cufile",
    srcs = ["targets/x86_64-linux/lib/libcufile.so"],
    visibility = ["//visibility:public"],
)

cc_library(
    name = "nvrtc",
    srcs = [

View File

@@ -312,6 +312,10 @@ if(USE_NUMPY)
  target_compile_definitions(torch_python PRIVATE USE_NUMPY)
endif()

if(USE_CUFILE AND NOT USE_ROCM)
  target_compile_definitions(torch_python PRIVATE USE_CUFILE)
endif()

if(HAVE_SOVERSION)
  set_target_properties(torch_python PROPERTIES
      VERSION ${TORCH_VERSION} SOVERSION ${TORCH_SOVERSION})

View File

@@ -1981,6 +1981,14 @@ def _can_use_cudnn_attention(params: _SDPAParams, debug: _bool) -> _bool: ...
def _can_use_flash_attention(params: _SDPAParams, debug: _bool) -> _bool: ...
def _can_use_mem_efficient_attention(params: _SDPAParams, debug: _bool) -> _bool: ...

# Defined in torch/csrc/cuda/GdsFile.cpp
def _gds_register_buffer(t: Storage) -> None: ...
def _gds_deregister_buffer(t: Storage) -> None: ...
def _gds_register_handle(fd: _int) -> _int: ...
def _gds_deregister_handle(handle: _int) -> None: ...
def _gds_load_storage(handle: _int, s: Storage, offset: _int) -> None: ...
def _gds_save_storage(handle: _int, s: Storage, offset: _int) -> None: ...

# Defined in torch/csrc/cuda/python_comm.cpp
def _broadcast(tensor: Tensor, devices: List[_int]) -> List[Tensor]: ...
def _broadcast_out(tensor: Tensor, out_tensors: List[Tensor]) -> List[Tensor]: ...
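These stubs expose the handle-based flow that _GdsFile wraps: register an O_DIRECT file descriptor, read/write whole storages at a byte offset, deregister. Used directly, a sketch under the same build assumptions as above (the path is again a placeholder):

import os
import torch

t = torch.randn(256, device="cuda")

# cuFile requires an O_DIRECT fd; _GdsFile adds the flag automatically,
# but at this level it must be passed explicitly.
fd = os.open("/mnt/nvme/raw.bin", os.O_CREAT | os.O_RDWR | os.O_DIRECT)
handle = torch._C._gds_register_handle(fd)  # wraps cuFileHandleRegister
try:
    torch._C._gds_save_storage(handle, t.untyped_storage(), 0)  # cuFileWrite
    torch._C._gds_load_storage(handle, t.untyped_storage(), 0)  # cuFileRead
finally:
    torch._C._gds_deregister_handle(handle)  # wraps cuFileHandleDeregister
    os.close(fd)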

View File

@@ -308,6 +308,7 @@ def _load_global_deps() -> None:
        "cuda_runtime": "libcudart.so.*[0-9]",
        "cuda_cupti": "libcupti.so.*[0-9]",
        "cufft": "libcufft.so.*[0-9]",
        "cufile": "libcufile.so.*[0-9]",
        "curand": "libcurand.so.*[0-9]",
        "nvjitlink": "libnvJitLink.so.*[0-9]",
        "cusparse": "libcusparse.so.*[0-9]",

torch/csrc/cuda/GdsFile.cpp (new file, 134 lines)
View File

@@ -0,0 +1,134 @@
#include <pybind11/pybind11.h>
#include <torch/csrc/utils/pybind.h>

#if defined(USE_CUFILE)

#include <c10/cuda/CUDAGuard.h>
#include <cuda_runtime.h>
#include <cufile.h>

namespace {

// To get error message for cuFileRead/Write APIs that return ssize_t (-1 for
// filesystem error and a negative CUfileOpError enum value otherwise).
template <
    class T,
    typename std::enable_if<std::is_integral<T>::value, std::nullptr_t>::type =
        nullptr>
std::string cuGDSFileGetErrorString(T status) {
  status = std::abs(status);
  return IS_CUFILE_ERR(status) ? std::string(CUFILE_ERRSTR(status))
                               : std::string(std::strerror(errno));
}

// To get error message for Buf/Handle registration APIs that return
// CUfileError_t
template <
    class T,
    typename std::enable_if<!std::is_integral<T>::value, std::nullptr_t>::type =
        nullptr>
std::string cuGDSFileGetErrorString(T status) {
  std::string errStr = cuGDSFileGetErrorString(static_cast<int>(status.err));
  if (IS_CUDA_ERR(status))
    errStr.append(".").append(
        cudaGetErrorString(static_cast<cudaError_t>(status.cu_err)));
  return errStr;
}

} // namespace

void gds_load_storage(
    int64_t handle,
    const at::Storage& storage,
    off_t offset) {
  // NOLINTNEXTLINE(performance-no-int-to-ptr)
  CUfileHandle_t cf_handle = reinterpret_cast<CUfileHandle_t>(handle);
  c10::cuda::CUDAGuard gpuGuard(storage.device());

  void* dataPtr = storage.mutable_data();
  const size_t nbytes = storage.nbytes();

  // Read the binary file
  ssize_t ret = cuFileRead(cf_handle, (void*)dataPtr, nbytes, offset, 0);
  TORCH_CHECK(ret >= 0, "cuFileRead failed: ", cuGDSFileGetErrorString(ret));
}

void gds_save_storage(
    int64_t handle,
    const at::Storage& storage,
    off_t offset) {
  // NOLINTNEXTLINE(performance-no-int-to-ptr)
  CUfileHandle_t cf_handle = reinterpret_cast<CUfileHandle_t>(handle);
  c10::cuda::CUDAGuard gpuGuard(storage.device());

  void* dataPtr = storage.mutable_data();
  const size_t nbytes = storage.nbytes();

  // Write device memory contents to the file
  ssize_t ret = cuFileWrite(cf_handle, dataPtr, nbytes, offset, 0);
  TORCH_CHECK(ret >= 0, "cuFileWrite failed: ", cuGDSFileGetErrorString(ret));
}

void gds_register_buffer(const at::Storage& storage) {
  void* dataPtr = storage.mutable_data();
  const size_t nbytes = storage.nbytes();

  CUfileError_t status = cuFileBufRegister(dataPtr, nbytes, 0);
  TORCH_CHECK(
      status.err == CU_FILE_SUCCESS,
      "cuFileBufRegister failed: ",
      cuGDSFileGetErrorString(status));
  return;
}

void gds_deregister_buffer(const at::Storage& storage) {
  void* dataPtr = storage.mutable_data();
  CUfileError_t status = cuFileBufDeregister(dataPtr);
  TORCH_CHECK(
      status.err == CU_FILE_SUCCESS,
      "cuFileBufDeregister failed: ",
      cuGDSFileGetErrorString(status));
  return;
}

int64_t gds_register_handle(int fd) {
  CUfileDescr_t cf_descr;
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  CUfileHandle_t cf_handle;
  memset((void*)&cf_descr, 0, sizeof(CUfileDescr_t));
  cf_descr.handle.fd = fd;
  cf_descr.type = CU_FILE_HANDLE_TYPE_OPAQUE_FD;
  CUfileError_t status = cuFileHandleRegister(&cf_handle, &cf_descr);
  if (status.err != CU_FILE_SUCCESS) {
    TORCH_CHECK(
        false,
        "cuFileHandleRegister failed: ",
        cuGDSFileGetErrorString(status));
  }

  // Returning cuFileHandle_t as int64_t
  return reinterpret_cast<int64_t>(cf_handle);
}

void gds_deregister_handle(int64_t handle) {
  // NOLINTNEXTLINE(performance-no-int-to-ptr)
  CUfileHandle_t cf_handle = reinterpret_cast<CUfileHandle_t>(handle);
  cuFileHandleDeregister(cf_handle);
}

#endif

namespace torch::cuda::shared {

void initGdsBindings(PyObject* module) {
  auto m = py::handle(module).cast<py::module>();

#if defined(USE_CUFILE)
  m.def("_gds_register_handle", &gds_register_handle);
  m.def("_gds_deregister_handle", &gds_deregister_handle);
  m.def("_gds_register_buffer", &gds_register_buffer);
  m.def("_gds_deregister_buffer", &gds_deregister_buffer);
  m.def("_gds_load_storage", &gds_load_storage);
  m.def("_gds_save_storage", &gds_save_storage);
#endif
}

} // namespace torch::cuda::shared

torch/csrc/cuda/GdsFile.h (new file, 7 lines)
View File

@@ -0,0 +1,7 @@
#ifndef THCP_GDSFILE_INC
#define THCP_GDSFILE_INC

#include <torch/csrc/python_headers.h>

void initGdsBindings(PyObject* module);
#endif // THCP_GDSFILE_INC

View File

@@ -35,6 +35,7 @@
#include <torch/csrc/CudaIPCTypes.h>
#include <torch/csrc/Generator.h>
#include <torch/csrc/cuda/CUDAPluggableAllocator.h>
#include <torch/csrc/cuda/GdsFile.h>
#include <torch/csrc/cuda/THCP.h>
#include <torch/csrc/cuda/memory_snapshot.h>
#include <torch/csrc/cuda/python_comm.h>
@@ -1963,6 +1964,7 @@ namespace shared {
void initCudartBindings(PyObject* module);
void initNvtxBindings(PyObject* module);
void initGdsBindings(PyObject* module);
#if defined(USE_CUDNN) || defined(USE_ROCM)
void initCudnnBindings(PyObject* module);
#endif
@@ -1978,6 +1980,7 @@ void initModule(PyObject* module) {
#if defined(USE_CUDNN) || defined(USE_ROCM)
  shared::initCudnnBindings(module);
#endif
  shared::initGdsBindings(module);
  registerCudaDeviceProperties(module);
  registerCudaPluggableAllocator(module);
}

View File

@@ -25,6 +25,7 @@ from torch import device as _device
from torch._utils import _dummy_type, _LazySeedTracker, classproperty
from torch.types import Device

from . import gds
from ._utils import _get_device_index
from .graphs import (
    CUDAGraph,
torch/cuda/gds.py (new file, 129 lines)
View File

@@ -0,0 +1,129 @@
import os
import sys
from typing import Callable, List, Optional

import torch
from torch.types import Storage

__all__: List[str] = []


def _dummy_fn(name: str) -> Callable:
    def fn(*args, **kwargs):  # type: ignore[no-untyped-def]
        raise RuntimeError(f"torch._C.{name} is not supported on this platform")

    return fn


if not hasattr(torch._C, "_gds_register_buffer"):
    assert not hasattr(torch._C, "_gds_deregister_buffer")
    assert not hasattr(torch._C, "_gds_register_handle")
    assert not hasattr(torch._C, "_gds_deregister_handle")
    assert not hasattr(torch._C, "_gds_load_storage")
    assert not hasattr(torch._C, "_gds_save_storage")
    # Define functions
    torch._C.__dict__["_gds_register_buffer"] = _dummy_fn("_gds_register_buffer")
    torch._C.__dict__["_gds_deregister_buffer"] = _dummy_fn("_gds_deregister_buffer")
    torch._C.__dict__["_gds_register_handle"] = _dummy_fn("_gds_register_handle")
    torch._C.__dict__["_gds_deregister_handle"] = _dummy_fn("_gds_deregister_handle")
    torch._C.__dict__["_gds_load_storage"] = _dummy_fn("_gds_load_storage")
    torch._C.__dict__["_gds_save_storage"] = _dummy_fn("_gds_save_storage")


def _gds_register_buffer(s: Storage) -> None:
    """Registers a buffer.

    Args:
        s (Storage): Buffer to register.
    """
    torch._C._gds_register_buffer(s)


def _gds_deregister_buffer(s: Storage) -> None:
    """Deregisters a previously registered buffer.

    Args:
        s (Storage): Buffer to deregister.
    """
    torch._C._gds_deregister_buffer(s)


class _GdsFile:
    r"""Wrapper around cuFile.

    cuFile is a file-like interface to the GPUDirect Storage (GDS) API.

    Args:
        filename (str): Name of the file to open.
        flags (int): Flags to pass to ``os.open`` when opening the file. ``os.O_DIRECT`` will
            be added automatically.

    .. _CUDA GPUDirect Storage Documentation:
        https://docs.nvidia.com/gpudirect-storage/api-reference-guide/index.html#cufile-io-api
    """

    def __init__(self, filename: str, flags: int):
        if sys.platform == "win32":
            raise RuntimeError("GdsFile is not supported on this platform.")
        self.filename = filename
        self.flags = flags
        self.fd = os.open(filename, flags | os.O_DIRECT)
        self.handle: Optional[int] = None
        self.register_handle()

    def __del__(self) -> None:
        if self.handle is not None:
            self.deregister_handle()
        os.close(self.fd)

    def register_handle(self) -> None:
        """Registers file descriptor to cuFile Driver.

        This is a wrapper around ``cuFileHandleRegister``.
        """
        assert (
            self.handle is None
        ), "Cannot register a handle that is already registered."
        self.handle = torch._C._gds_register_handle(self.fd)

    def deregister_handle(self) -> None:
        """Deregisters file descriptor from cuFile Driver.

        This is a wrapper around ``cuFileHandleDeregister``.
        """
        assert (
            self.handle is not None
        ), "Cannot deregister a handle that is not registered."
        torch._C._gds_deregister_handle(self.handle)
        self.handle = None

    def load_storage(self, storage: Storage, offset: int = 0) -> None:
        """Loads data from the file into the storage.

        This is a wrapper around ``cuFileRead``. ``storage.nbytes()`` of data
        will be loaded from the file at ``offset`` into the storage.

        Args:
            storage (Storage): Storage to load data into.
            offset (int, optional): Offset into the file to start loading from. (Default: 0)
        """
        assert (
            self.handle is not None
        ), "Cannot load data from a file that is not registered."
        torch._C._gds_load_storage(self.handle, storage, offset)

    def save_storage(self, storage: Storage, offset: int = 0) -> None:
        """Saves data from the storage into the file.

        This is a wrapper around ``cuFileWrite``. All bytes of the storage
        will be written to the file at ``offset``.

        Args:
            storage (Storage): Storage to save data from.
            offset (int, optional): Offset into the file to start saving to. (Default: 0)
        """
        assert (
            self.handle is not None
        ), "Cannot save data to a file that is not registered."
        torch._C._gds_save_storage(self.handle, storage, offset)
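_gds_register_buffer / _gds_deregister_buffer (wrapping cuFileBufRegister / cuFileBufDeregister) are optional: cuFile can transfer directly to and from a registered storage, while unregistered I/O goes through cuFile's internal intermediate buffers. A registered-buffer sketch mirroring the TestGDS test above (same build assumptions; the path is a placeholder):

import os
import torch
from torch.cuda.gds import _GdsFile, _gds_deregister_buffer, _gds_register_buffer

src = torch.randn(1024, device="cuda")
_gds_register_buffer(src.untyped_storage())  # pins storage via cuFileBufRegister
try:
    f = _GdsFile("/mnt/nvme/registered.bin", os.O_CREAT | os.O_RDWR)
    f.save_storage(src.untyped_storage(), offset=0)
    f.deregister_handle()
finally:
    # Deregister before the storage is freed.
    _gds_deregister_buffer(src.untyped_storage())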