mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 00:20:18 +01:00
[Reland] Add wrappers for synchronous GPUDirect Storage APIs (#133489)
Reland #130633. USE_CUFILE is turned off by default in this version. Pull Request resolved: https://github.com/pytorch/pytorch/pull/133489 Approved by: https://github.com/albanD
This commit is contained in:
parent
c23dceb8f1
commit
018e48c337
|
|
@ -251,6 +251,15 @@ cmake_dependent_option(USE_CUDNN "Use cuDNN" ON "USE_CUDA" OFF)
|
||||||
cmake_dependent_option(USE_STATIC_CUDNN "Use cuDNN static libraries" OFF
|
cmake_dependent_option(USE_STATIC_CUDNN "Use cuDNN static libraries" OFF
|
||||||
"USE_CUDNN" OFF)
|
"USE_CUDNN" OFF)
|
||||||
cmake_dependent_option(USE_CUSPARSELT "Use cuSPARSELt" ON "USE_CUDA" OFF)
|
cmake_dependent_option(USE_CUSPARSELT "Use cuSPARSELt" ON "USE_CUDA" OFF)
|
||||||
|
# Binary builds will fail for cufile due to https://github.com/pytorch/builder/issues/1924
|
||||||
|
# TH_BINARY_BUILD is used to detect whether this is a binary build.
|
||||||
|
# USE_ROCM is guarded against in Dependencies.cmake because USE_ROCM is not properly defined here
|
||||||
|
# USE_CUFILE defaults to OFF and can only be enabled with CUDA on non-Windows
# platforms; binary builds additionally gate it on TH_BINARY_BUILD.
# The value of TH_BINARY_BUILD is spliced textually into the dependency
# expression, so it is only used when non-empty: an empty value would leave a
# dangling `NOT` in the expression and make the condition malformed.
set(_use_cufile_depends "USE_CUDA AND NOT WIN32")
if(DEFINED ENV{TH_BINARY_BUILD} AND NOT "$ENV{TH_BINARY_BUILD}" STREQUAL "")
  set(_use_cufile_depends
      "USE_CUDA AND NOT $ENV{TH_BINARY_BUILD} AND NOT WIN32")
endif()
cmake_dependent_option(USE_CUFILE "Use cuFile" OFF "${_use_cufile_depends}" OFF)
unset(_use_cufile_depends)
|
||||||
option(USE_FBGEMM "Use FBGEMM (quantized 8-bit server operators)" ON)
|
option(USE_FBGEMM "Use FBGEMM (quantized 8-bit server operators)" ON)
|
||||||
option(USE_KINETO "Use Kineto profiling library" ON)
|
option(USE_KINETO "Use Kineto profiling library" ON)
|
||||||
option(USE_CUPTI_SO "Use CUPTI as a shared library" ON)
|
option(USE_CUPTI_SO "Use CUPTI as a shared library" ON)
|
||||||
|
|
|
||||||
|
|
@ -773,6 +773,7 @@ libtorch_python_cuda_core_sources = [
|
||||||
"torch/csrc/cuda/shared/cudart.cpp",
|
"torch/csrc/cuda/shared/cudart.cpp",
|
||||||
"torch/csrc/cuda/shared/nvtx.cpp",
|
"torch/csrc/cuda/shared/nvtx.cpp",
|
||||||
"torch/csrc/cuda/utils.cpp",
|
"torch/csrc/cuda/utils.cpp",
|
||||||
|
"torch/csrc/cuda/GdsFile.cpp",
|
||||||
]
|
]
|
||||||
|
|
||||||
libtorch_python_cuda_sources = libtorch_python_cuda_core_sources + [
|
libtorch_python_cuda_sources = libtorch_python_cuda_core_sources + [
|
||||||
|
|
|
||||||
|
|
@ -928,6 +928,10 @@ elseif(USE_CUDA)
|
||||||
torch_compile_options(torch_cuda) # see cmake/public/utils.cmake
|
torch_compile_options(torch_cuda) # see cmake/public/utils.cmake
|
||||||
target_compile_definitions(torch_cuda PRIVATE USE_CUDA)
|
target_compile_definitions(torch_cuda PRIVATE USE_CUDA)
|
||||||
|
|
||||||
|
if(USE_CUFILE)
|
||||||
|
target_link_libraries(torch_cuda PRIVATE torch::cufile)
|
||||||
|
target_compile_definitions(torch_cuda PRIVATE USE_CUFILE)
|
||||||
|
endif()
|
||||||
if(USE_CUSPARSELT)
|
if(USE_CUSPARSELT)
|
||||||
target_link_libraries(torch_cuda PRIVATE torch::cusparselt)
|
target_link_libraries(torch_cuda PRIVATE torch::cusparselt)
|
||||||
target_compile_definitions(torch_cuda PRIVATE USE_CUSPARSELT)
|
target_compile_definitions(torch_cuda PRIVATE USE_CUSPARSELT)
|
||||||
|
|
|
||||||
|
|
@ -39,6 +39,7 @@ if(USE_CUDA)
|
||||||
set(CAFFE2_USE_CUDA ${USE_CUDA})
|
set(CAFFE2_USE_CUDA ${USE_CUDA})
|
||||||
set(CAFFE2_USE_CUDNN ${USE_CUDNN})
|
set(CAFFE2_USE_CUDNN ${USE_CUDNN})
|
||||||
set(CAFFE2_USE_CUSPARSELT ${USE_CUSPARSELT})
|
set(CAFFE2_USE_CUSPARSELT ${USE_CUSPARSELT})
|
||||||
|
set(CAFFE2_USE_CUFILE ${USE_CUFILE})
|
||||||
set(CAFFE2_USE_NVRTC ${USE_NVRTC})
|
set(CAFFE2_USE_NVRTC ${USE_NVRTC})
|
||||||
include(${CMAKE_CURRENT_LIST_DIR}/public/cuda.cmake)
|
include(${CMAKE_CURRENT_LIST_DIR}/public/cuda.cmake)
|
||||||
if(CAFFE2_USE_CUDA)
|
if(CAFFE2_USE_CUDA)
|
||||||
|
|
@ -60,6 +61,9 @@ if(USE_CUDA)
|
||||||
else()
|
else()
|
||||||
caffe2_update_option(USE_CUSPARSELT OFF)
|
caffe2_update_option(USE_CUSPARSELT OFF)
|
||||||
endif()
|
endif()
|
||||||
|
if(CAFFE2_USE_CUFILE)
|
||||||
|
list(APPEND Caffe2_CUDA_DEPENDENCY_LIBS torch::cufile)
|
||||||
|
endif()
|
||||||
find_program(SCCACHE_EXECUTABLE sccache)
|
find_program(SCCACHE_EXECUTABLE sccache)
|
||||||
if(SCCACHE_EXECUTABLE)
|
if(SCCACHE_EXECUTABLE)
|
||||||
# Using RSP/--options-file renders output noncacheable by sccache
|
# Using RSP/--options-file renders output noncacheable by sccache
|
||||||
|
|
@ -79,6 +83,7 @@ if(USE_CUDA)
|
||||||
set(CAFFE2_USE_CUDA OFF)
|
set(CAFFE2_USE_CUDA OFF)
|
||||||
set(CAFFE2_USE_CUDNN OFF)
|
set(CAFFE2_USE_CUDNN OFF)
|
||||||
set(CAFFE2_USE_CUSPARSELT OFF)
|
set(CAFFE2_USE_CUSPARSELT OFF)
|
||||||
|
set(CAFFE2_USE_CUFILE OFF)
|
||||||
set(CAFFE2_USE_NVRTC OFF)
|
set(CAFFE2_USE_NVRTC OFF)
|
||||||
endif()
|
endif()
|
||||||
endif()
|
endif()
|
||||||
|
|
@ -1039,7 +1044,6 @@ if(USE_ROCM)
|
||||||
caffe2_update_option(USE_SYSTEM_NCCL ON)
|
caffe2_update_option(USE_SYSTEM_NCCL ON)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
|
|
||||||
list(APPEND HIP_CXX_FLAGS -fPIC)
|
list(APPEND HIP_CXX_FLAGS -fPIC)
|
||||||
list(APPEND HIP_CXX_FLAGS -D__HIP_PLATFORM_AMD__=1)
|
list(APPEND HIP_CXX_FLAGS -D__HIP_PLATFORM_AMD__=1)
|
||||||
list(APPEND HIP_CXX_FLAGS -DCUDA_HAS_FP16=1)
|
list(APPEND HIP_CXX_FLAGS -DCUDA_HAS_FP16=1)
|
||||||
|
|
|
||||||
|
|
@ -978,6 +978,14 @@ if(CUDAToolkit_FOUND)
|
||||||
_CUDAToolkit_find_and_add_import_lib(cublas_static DEPS culibos)
|
_CUDAToolkit_find_and_add_import_lib(cublas_static DEPS culibos)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
|
if(CUDAToolkit_VERSION VERSION_GREATER_EQUAL 11.4)
|
||||||
|
_CUDAToolkit_find_and_add_import_lib(cuFile ALT cufile DEPS culibos)
|
||||||
|
_CUDAToolkit_find_and_add_import_lib(cuFile_static ALT cufile_static DEPS culibos)
|
||||||
|
|
||||||
|
_CUDAToolkit_find_and_add_import_lib(cuFile_rdma ALT cufile_rdma DEPS cuFile culibos)
|
||||||
|
_CUDAToolkit_find_and_add_import_lib(cuFile_rdma_static ALT cufile_rdma_static DEPS cuFile_static culibos)
|
||||||
|
endif()
|
||||||
|
|
||||||
# cuFFTW depends on cuFFT
|
# cuFFTW depends on cuFFT
|
||||||
_CUDAToolkit_find_and_add_import_lib(cufftw DEPS cufft)
|
_CUDAToolkit_find_and_add_import_lib(cufftw DEPS cufft)
|
||||||
_CUDAToolkit_find_and_add_import_lib(cufftw_static DEPS cufft_static)
|
_CUDAToolkit_find_and_add_import_lib(cufftw_static DEPS cufft_static)
|
||||||
|
|
|
||||||
|
|
@ -74,6 +74,7 @@ function(caffe2_print_configuration_summary)
|
||||||
message(STATUS " CUDA static link : ${CAFFE2_STATIC_LINK_CUDA}")
|
message(STATUS " CUDA static link : ${CAFFE2_STATIC_LINK_CUDA}")
|
||||||
message(STATUS " USE_CUDNN : ${USE_CUDNN}")
|
message(STATUS " USE_CUDNN : ${USE_CUDNN}")
|
||||||
message(STATUS " USE_CUSPARSELT : ${USE_CUSPARSELT}")
|
message(STATUS " USE_CUSPARSELT : ${USE_CUSPARSELT}")
|
||||||
|
message(STATUS " USE_CUFILE : ${USE_CUFILE}")
|
||||||
message(STATUS " CUDA version : ${CUDA_VERSION}")
|
message(STATUS " CUDA version : ${CUDA_VERSION}")
|
||||||
message(STATUS " USE_FLASH_ATTENTION : ${USE_FLASH_ATTENTION}")
|
message(STATUS " USE_FLASH_ATTENTION : ${USE_FLASH_ATTENTION}")
|
||||||
message(STATUS " USE_MEM_EFF_ATTENTION : ${USE_MEM_EFF_ATTENTION}")
|
message(STATUS " USE_MEM_EFF_ATTENTION : ${USE_MEM_EFF_ATTENTION}")
|
||||||
|
|
@ -83,6 +84,9 @@ function(caffe2_print_configuration_summary)
|
||||||
if(${USE_CUSPARSELT})
|
if(${USE_CUSPARSELT})
|
||||||
message(STATUS " cuSPARSELt version : ${CUSPARSELT_VERSION}")
|
message(STATUS " cuSPARSELt version : ${CUSPARSELT_VERSION}")
|
||||||
endif()
|
endif()
|
||||||
|
if(${USE_CUFILE})
|
||||||
|
message(STATUS " cufile library : ${CUDA_cuFile_LIBRARY}")
|
||||||
|
endif()
|
||||||
message(STATUS " CUDA root directory : ${CUDA_TOOLKIT_ROOT_DIR}")
|
message(STATUS " CUDA root directory : ${CUDA_TOOLKIT_ROOT_DIR}")
|
||||||
message(STATUS " CUDA library : ${CUDA_cuda_driver_LIBRARY}")
|
message(STATUS " CUDA library : ${CUDA_cuda_driver_LIBRARY}")
|
||||||
message(STATUS " cudart library : ${CUDA_cudart_LIBRARY}")
|
message(STATUS " cudart library : ${CUDA_cudart_LIBRARY}")
|
||||||
|
|
|
||||||
|
|
@ -244,6 +244,22 @@ else()
|
||||||
message(STATUS "USE_CUSPARSELT is set to 0. Compiling without cuSPARSELt support")
|
message(STATUS "USE_CUSPARSELT is set to 0. Compiling without cuSPARSELt support")
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
|
# cufile
|
||||||
|
if(CAFFE2_USE_CUFILE)
  # Choose the static or shared cuFile import target, following the same
  # CAFFE2_STATIC_LINK_CUDA / WIN32 rule used to pick between the two.
  if(CAFFE2_STATIC_LINK_CUDA AND NOT WIN32)
    set(_torch_cufile_import CUDA::cuFile_static)
  else()
    set(_torch_cufile_import CUDA::cuFile)
  endif()
  # torch::cufile is a thin interface target that forwards to the chosen import.
  add_library(torch::cufile INTERFACE IMPORTED)
  set_property(
      TARGET torch::cufile PROPERTY INTERFACE_LINK_LIBRARIES
      ${_torch_cufile_import})
  unset(_torch_cufile_import)
else()
  message(STATUS "USE_CUFILE is set to 0. Compiling without cuFile support")
endif()
|
||||||
|
|
||||||
# curand
|
# curand
|
||||||
add_library(caffe2::curand INTERFACE IMPORTED)
|
add_library(caffe2::curand INTERFACE IMPORTED)
|
||||||
if(CAFFE2_STATIC_LINK_CUDA AND NOT WIN32)
|
if(CAFFE2_STATIC_LINK_CUDA AND NOT WIN32)
|
||||||
|
|
|
||||||
|
|
@ -183,6 +183,7 @@ See the :doc:`documentation <cuda._sanitizer>` for information on how to use it.
|
||||||
.. for tracking purposes
|
.. for tracking purposes
|
||||||
.. py:module:: torch.cuda.comm
|
.. py:module:: torch.cuda.comm
|
||||||
.. py:module:: torch.cuda.error
|
.. py:module:: torch.cuda.error
|
||||||
|
.. py:module:: torch.cuda.gds
|
||||||
.. py:module:: torch.cuda.graphs
|
.. py:module:: torch.cuda.graphs
|
||||||
.. py:module:: torch.cuda.jiterator
|
.. py:module:: torch.cuda.jiterator
|
||||||
.. py:module:: torch.cuda.memory
|
.. py:module:: torch.cuda.memory
|
||||||
|
|
|
||||||
3
setup.py
3
setup.py
|
|
@ -38,6 +38,9 @@
|
||||||
# USE_CUSPARSELT=0
|
# USE_CUSPARSELT=0
|
||||||
# disables the cuSPARSELt build
|
# disables the cuSPARSELt build
|
||||||
#
|
#
|
||||||
|
# USE_CUFILE=0
|
||||||
|
# disables the cuFile build
|
||||||
|
#
|
||||||
# USE_FBGEMM=0
|
# USE_FBGEMM=0
|
||||||
# disables the FBGEMM build
|
# disables the FBGEMM build
|
||||||
#
|
#
|
||||||
|
|
|
||||||
|
|
@ -17,6 +17,8 @@ from copy import deepcopy
|
||||||
from itertools import product
|
from itertools import product
|
||||||
from random import randint
|
from random import randint
|
||||||
|
|
||||||
|
import psutil
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.cuda
|
import torch.cuda
|
||||||
from torch import inf, nan
|
from torch import inf, nan
|
||||||
|
|
@ -62,6 +64,7 @@ from torch.testing._internal.common_utils import (
|
||||||
skipIfRocm,
|
skipIfRocm,
|
||||||
slowTest,
|
slowTest,
|
||||||
subtest,
|
subtest,
|
||||||
|
TemporaryFileName,
|
||||||
TEST_CUDA,
|
TEST_CUDA,
|
||||||
TEST_CUDA_GRAPH,
|
TEST_CUDA_GRAPH,
|
||||||
TEST_NUMPY,
|
TEST_NUMPY,
|
||||||
|
|
@ -3998,6 +4001,16 @@ print(f"{{r1}}, {{r2}}")
|
||||||
x = torch.cuda.device_count()
|
x = torch.cuda.device_count()
|
||||||
self.assertEqual(f"{x}, 1", r)
|
self.assertEqual(f"{x}, 1", r)
|
||||||
|
|
||||||
|
@unittest.skip("Disabling as USE_CUFILE=0 by default in builds")
def test_gds_fails_in_ci(self):
    """Opening a ``_GdsFile`` must raise a RuntimeError with the expected message."""
    # Windows and ROCm are expected to report the "not supported" error;
    # everywhere else the cuFile handle registration is expected to fail.
    unsupported_platform = IS_WINDOWS or TEST_WITH_ROCM
    error_msg = (
        "is not supported on this platform"
        if unsupported_platform
        else "cuFileHandleRegister failed"
    )
    with TemporaryFileName() as f, self.assertRaisesRegex(RuntimeError, error_msg):
        file = torch.cuda.gds._GdsFile(f, os.O_CREAT | os.O_RDWR)
|
||||||
|
|
||||||
|
|
||||||
@torch.testing._internal.common_utils.markDynamoStrictTest
|
@torch.testing._internal.common_utils.markDynamoStrictTest
|
||||||
class TestCudaMallocAsync(TestCase):
|
class TestCudaMallocAsync(TestCase):
|
||||||
|
|
@ -5242,6 +5255,40 @@ class TestCudaOptims(TestCase):
|
||||||
self.assertEqual(scaler._growth_tracker, growth_tracker)
|
self.assertEqual(scaler._growth_tracker, growth_tracker)
|
||||||
|
|
||||||
|
|
||||||
|
class TestGDS(TestCase):
    """Tests for the GPUDirect Storage wrappers in ``torch.cuda.gds``."""

    def _get_tmp_dir_fs_type(self):
        """Return the filesystem type psutil reports for ``/tmp``.

        If no partition is mounted exactly at the resolved ``/tmp`` path, the
        type of the root mount is returned instead (``""`` when psutil lists
        no root mount either).
        """
        my_path = os.path.realpath("/tmp")
        root_type = ""
        for part in psutil.disk_partitions():
            if part.mountpoint == "/":
                root_type = part.fstype
                continue
            if part.mountpoint == my_path:
                return part.fstype
        return root_type

    @unittest.skip("Disabling as USE_CUFILE=0 by default in builds")
    def test_gds_read_write_tensors(self):
        """Round-trip two CUDA tensors through a ``_GdsFile`` and compare them."""
        if self._get_tmp_dir_fs_type() not in ("ext4", "xfs"):
            self.skipTest("GPUDirect Storage requires ext4/xfs for local filesystem")
        src1 = torch.randn(1024, device="cuda")
        src2 = torch.randn(2, 1024, device="cuda")
        # Schedule deregistration right after each registration so the buffers
        # are released even when a later step of the test raises.
        torch.cuda.gds._gds_register_buffer(src1.untyped_storage())
        self.addCleanup(torch.cuda.gds._gds_deregister_buffer, src1.untyped_storage())
        torch.cuda.gds._gds_register_buffer(src2.untyped_storage())
        self.addCleanup(torch.cuda.gds._gds_deregister_buffer, src2.untyped_storage())
        dest1 = torch.empty(1024, device="cuda")
        dest2 = torch.empty(2, 1024, device="cuda")
        with TemporaryFileName() as f:
            file = torch.cuda.gds._GdsFile(f, os.O_CREAT | os.O_RDWR)
            # src2 is written directly after src1 in the file.
            file.save_storage(src1.untyped_storage(), offset=0)
            file.save_storage(src2.untyped_storage(), offset=src1.nbytes)
            file.load_storage(dest1.untyped_storage(), offset=0)
            file.load_storage(dest2.untyped_storage(), offset=src1.nbytes)
        self.assertEqual(src1, dest1)
        self.assertEqual(src2, dest2)
|
||||||
|
|
||||||
|
|
||||||
instantiate_parametrized_tests(TestCuda)
|
instantiate_parametrized_tests(TestCuda)
|
||||||
instantiate_parametrized_tests(TestCudaMallocAsync)
|
instantiate_parametrized_tests(TestCudaMallocAsync)
|
||||||
instantiate_device_type_tests(TestCudaOptims, globals())
|
instantiate_device_type_tests(TestCudaOptims, globals())
|
||||||
|
|
|
||||||
6
third_party/cuda.BUILD
vendored
6
third_party/cuda.BUILD
vendored
|
|
@ -60,6 +60,12 @@ cc_library(
|
||||||
visibility = ["//visibility:public"],
|
visibility = ["//visibility:public"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# cuFile shared library taken from the toolkit's x86_64-linux target
# directory; exposed publicly so other packages can depend on it.
cc_library(
    name = "cufile",
    srcs = ["targets/x86_64-linux/lib/libcufile.so"],
    visibility = ["//visibility:public"],
)
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "nvrtc",
|
name = "nvrtc",
|
||||||
srcs = [
|
srcs = [
|
||||||
|
|
|
||||||
|
|
@ -312,6 +312,10 @@ if(USE_NUMPY)
|
||||||
target_compile_definitions(torch_python PRIVATE USE_NUMPY)
|
target_compile_definitions(torch_python PRIVATE USE_NUMPY)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
|
# Define USE_CUFILE for the torch_python sources when cuFile is enabled and
# this is not a ROCm build.
if(USE_CUFILE AND NOT USE_ROCM)
  target_compile_definitions(torch_python PRIVATE USE_CUFILE)
endif()
|
||||||
|
|
||||||
if(HAVE_SOVERSION)
|
if(HAVE_SOVERSION)
|
||||||
set_target_properties(torch_python PROPERTIES
|
set_target_properties(torch_python PROPERTIES
|
||||||
VERSION ${TORCH_VERSION} SOVERSION ${TORCH_SOVERSION})
|
VERSION ${TORCH_VERSION} SOVERSION ${TORCH_SOVERSION})
|
||||||
|
|
|
||||||
|
|
@ -1981,6 +1981,14 @@ def _can_use_cudnn_attention(params: _SDPAParams, debug: _bool) -> _bool: ...
|
||||||
def _can_use_flash_attention(params: _SDPAParams, debug: _bool) -> _bool: ...
|
def _can_use_flash_attention(params: _SDPAParams, debug: _bool) -> _bool: ...
|
||||||
def _can_use_mem_efficient_attention(params: _SDPAParams, debug: _bool) -> _bool: ...
|
def _can_use_mem_efficient_attention(params: _SDPAParams, debug: _bool) -> _bool: ...
|
||||||
|
|
||||||
|
# Defined in torch/csrc/cuda/GdsFile.cpp
|
||||||
|
def _gds_register_buffer(t: Storage) -> None: ...
|
||||||
|
def _gds_deregister_buffer(t: Storage) -> None: ...
|
||||||
|
def _gds_register_handle(fd: _int) -> _int: ...
|
||||||
|
def _gds_deregister_handle(handle: _int) -> None: ...
|
||||||
|
def _gds_load_storage(handle: _int, s: Storage, offset: _int) -> None: ...
|
||||||
|
def _gds_save_storage(handle: _int, s: Storage, offset: _int) -> None: ...
|
||||||
|
|
||||||
# Defined in torch/csrc/cuda/python_comm.cpp
|
# Defined in torch/csrc/cuda/python_comm.cpp
|
||||||
def _broadcast(tensor: Tensor, devices: List[_int]) -> List[Tensor]: ...
|
def _broadcast(tensor: Tensor, devices: List[_int]) -> List[Tensor]: ...
|
||||||
def _broadcast_out(tensor: Tensor, out_tensors: List[Tensor]) -> List[Tensor]: ...
|
def _broadcast_out(tensor: Tensor, out_tensors: List[Tensor]) -> List[Tensor]: ...
|
||||||
|
|
|
||||||
|
|
@ -308,6 +308,7 @@ def _load_global_deps() -> None:
|
||||||
"cuda_runtime": "libcudart.so.*[0-9]",
|
"cuda_runtime": "libcudart.so.*[0-9]",
|
||||||
"cuda_cupti": "libcupti.so.*[0-9]",
|
"cuda_cupti": "libcupti.so.*[0-9]",
|
||||||
"cufft": "libcufft.so.*[0-9]",
|
"cufft": "libcufft.so.*[0-9]",
|
||||||
|
"cufile": "libcufile.so.*[0-9]",
|
||||||
"curand": "libcurand.so.*[0-9]",
|
"curand": "libcurand.so.*[0-9]",
|
||||||
"nvjitlink": "libnvJitLink.so.*[0-9]",
|
"nvjitlink": "libnvJitLink.so.*[0-9]",
|
||||||
"cusparse": "libcusparse.so.*[0-9]",
|
"cusparse": "libcusparse.so.*[0-9]",
|
||||||
|
|
|
||||||
134
torch/csrc/cuda/GdsFile.cpp
Normal file
134
torch/csrc/cuda/GdsFile.cpp
Normal file
|
|
@ -0,0 +1,134 @@
|
||||||
|
#include <pybind11/pybind11.h>
|
||||||
|
#include <torch/csrc/utils/pybind.h>
|
||||||
|
|
||||||
|
#if defined(USE_CUFILE)
|
||||||
|
#include <c10/cuda/CUDAGuard.h>
|
||||||
|
|
||||||
|
#include <cuda_runtime.h>
|
||||||
|
#include <cufile.h>
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
// To get error message for cuFileRead/Write APIs that return ssize_t (-1 for
|
||||||
|
// filesystem error and a negative CUfileOpError enum value otherwise).
|
||||||
|
template <
|
||||||
|
class T,
|
||||||
|
typename std::enable_if<std::is_integral<T>::value, std::nullptr_t>::type =
|
||||||
|
nullptr>
|
||||||
|
std::string cuGDSFileGetErrorString(T status) {
|
||||||
|
status = std::abs(status);
|
||||||
|
return IS_CUFILE_ERR(status) ? std::string(CUFILE_ERRSTR(status))
|
||||||
|
: std::string(std::strerror(errno));
|
||||||
|
}
|
||||||
|
|
||||||
|
// To get error message for Buf/Handle registeration APIs that return
|
||||||
|
// CUfileError_t
|
||||||
|
template <
|
||||||
|
class T,
|
||||||
|
typename std::enable_if<!std::is_integral<T>::value, std::nullptr_t>::type =
|
||||||
|
nullptr>
|
||||||
|
std::string cuGDSFileGetErrorString(T status) {
|
||||||
|
std::string errStr = cuGDSFileGetErrorString(static_cast<int>(status.err));
|
||||||
|
if (IS_CUDA_ERR(status))
|
||||||
|
errStr.append(".").append(
|
||||||
|
cudaGetErrorString(static_cast<cudaError_t>(status.cu_err)));
|
||||||
|
return errStr;
|
||||||
|
}
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
void gds_load_storage(
|
||||||
|
int64_t handle,
|
||||||
|
const at::Storage& storage,
|
||||||
|
off_t offset) {
|
||||||
|
// NOLINTNEXTLINE(performance-no-int-to-ptr)
|
||||||
|
CUfileHandle_t cf_handle = reinterpret_cast<CUfileHandle_t>(handle);
|
||||||
|
c10::cuda::CUDAGuard gpuGuard(storage.device());
|
||||||
|
|
||||||
|
void* dataPtr = storage.mutable_data();
|
||||||
|
const size_t nbytes = storage.nbytes();
|
||||||
|
|
||||||
|
// Read the binary file
|
||||||
|
ssize_t ret = cuFileRead(cf_handle, (void*)dataPtr, nbytes, offset, 0);
|
||||||
|
TORCH_CHECK(ret >= 0, "cuFileRead failed: ", cuGDSFileGetErrorString(ret));
|
||||||
|
}
|
||||||
|
|
||||||
|
void gds_save_storage(
|
||||||
|
int64_t handle,
|
||||||
|
const at::Storage& storage,
|
||||||
|
off_t offset) {
|
||||||
|
// NOLINTNEXTLINE(performance-no-int-to-ptr)
|
||||||
|
CUfileHandle_t cf_handle = reinterpret_cast<CUfileHandle_t>(handle);
|
||||||
|
c10::cuda::CUDAGuard gpuGuard(storage.device());
|
||||||
|
|
||||||
|
void* dataPtr = storage.mutable_data();
|
||||||
|
const size_t nbytes = storage.nbytes();
|
||||||
|
|
||||||
|
// Write device memory contents to the file
|
||||||
|
ssize_t ret = cuFileWrite(cf_handle, dataPtr, nbytes, offset, 0);
|
||||||
|
TORCH_CHECK(ret >= 0, "cuFileWrite failed: ", cuGDSFileGetErrorString(ret));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Registers the storage's whole buffer (data pointer + nbytes) with cuFile via
// cuFileBufRegister. Fails through TORCH_CHECK if cuFile reports an error.
void gds_register_buffer(const at::Storage& storage) {
  void* buffer = storage.mutable_data();
  const size_t byte_count = storage.nbytes();

  const CUfileError_t result = cuFileBufRegister(buffer, byte_count, 0);
  TORCH_CHECK(
      result.err == CU_FILE_SUCCESS,
      "cuFileBufRegister failed: ",
      cuGDSFileGetErrorString(result));
}
|
||||||
|
|
||||||
|
// Deregisters the storage's buffer from cuFile via cuFileBufDeregister.
// Fails through TORCH_CHECK if cuFile reports an error.
void gds_deregister_buffer(const at::Storage& storage) {
  void* buffer = storage.mutable_data();

  const CUfileError_t result = cuFileBufDeregister(buffer);
  TORCH_CHECK(
      result.err == CU_FILE_SUCCESS,
      "cuFileBufDeregister failed: ",
      cuGDSFileGetErrorString(result));
}
|
||||||
|
|
||||||
|
int64_t gds_register_handle(int fd) {
|
||||||
|
CUfileDescr_t cf_descr;
|
||||||
|
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
|
||||||
|
CUfileHandle_t cf_handle;
|
||||||
|
memset((void*)&cf_descr, 0, sizeof(CUfileDescr_t));
|
||||||
|
cf_descr.handle.fd = fd;
|
||||||
|
cf_descr.type = CU_FILE_HANDLE_TYPE_OPAQUE_FD;
|
||||||
|
CUfileError_t status = cuFileHandleRegister(&cf_handle, &cf_descr);
|
||||||
|
if (status.err != CU_FILE_SUCCESS) {
|
||||||
|
TORCH_CHECK(
|
||||||
|
false,
|
||||||
|
"cuFileHandleRegister failed: ",
|
||||||
|
cuGDSFileGetErrorString(status));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Returning cuFileHandle_t as int64_t
|
||||||
|
return reinterpret_cast<int64_t>(cf_handle);
|
||||||
|
}
|
||||||
|
|
||||||
|
void gds_deregister_handle(int64_t handle) {
|
||||||
|
// NOLINTNEXTLINE(performance-no-int-to-ptr)
|
||||||
|
CUfileHandle_t cf_handle = reinterpret_cast<CUfileHandle_t>(handle);
|
||||||
|
cuFileHandleDeregister(cf_handle);
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif
|
||||||
|
|
||||||
|
namespace torch::cuda::shared {
|
||||||
|
|
||||||
|
void initGdsBindings(PyObject* module) {
|
||||||
|
auto m = py::handle(module).cast<py::module>();
|
||||||
|
|
||||||
|
#if defined(USE_CUFILE)
|
||||||
|
m.def("_gds_register_handle", &gds_register_handle);
|
||||||
|
m.def("_gds_deregister_handle", &gds_deregister_handle);
|
||||||
|
m.def("_gds_register_buffer", &gds_register_buffer);
|
||||||
|
m.def("_gds_deregister_buffer", &gds_deregister_buffer);
|
||||||
|
m.def("_gds_load_storage", &gds_load_storage);
|
||||||
|
m.def("_gds_save_storage", &gds_save_storage);
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace torch::cuda::shared
|
||||||
7
torch/csrc/cuda/GdsFile.h
Normal file
7
torch/csrc/cuda/GdsFile.h
Normal file
|
|
@ -0,0 +1,7 @@
|
||||||
|
#ifndef THCP_GDSFILE_INC
#define THCP_GDSFILE_INC

#include <torch/csrc/python_headers.h>

namespace torch::cuda::shared {

// Registers the private _gds_* cuFile bindings on the given Python module.
// Declared in torch::cuda::shared to match the definition in GdsFile.cpp; a
// global-namespace declaration would name a function that is never defined.
void initGdsBindings(PyObject* module);

} // namespace torch::cuda::shared

#endif // THCP_GDSFILE_INC
|
||||||
|
|
@ -35,6 +35,7 @@
|
||||||
#include <torch/csrc/CudaIPCTypes.h>
|
#include <torch/csrc/CudaIPCTypes.h>
|
||||||
#include <torch/csrc/Generator.h>
|
#include <torch/csrc/Generator.h>
|
||||||
#include <torch/csrc/cuda/CUDAPluggableAllocator.h>
|
#include <torch/csrc/cuda/CUDAPluggableAllocator.h>
|
||||||
|
#include <torch/csrc/cuda/GdsFile.h>
|
||||||
#include <torch/csrc/cuda/THCP.h>
|
#include <torch/csrc/cuda/THCP.h>
|
||||||
#include <torch/csrc/cuda/memory_snapshot.h>
|
#include <torch/csrc/cuda/memory_snapshot.h>
|
||||||
#include <torch/csrc/cuda/python_comm.h>
|
#include <torch/csrc/cuda/python_comm.h>
|
||||||
|
|
@ -1963,6 +1964,7 @@ namespace shared {
|
||||||
|
|
||||||
void initCudartBindings(PyObject* module);
|
void initCudartBindings(PyObject* module);
|
||||||
void initNvtxBindings(PyObject* module);
|
void initNvtxBindings(PyObject* module);
|
||||||
|
void initGdsBindings(PyObject* module);
|
||||||
#if defined(USE_CUDNN) || defined(USE_ROCM)
|
#if defined(USE_CUDNN) || defined(USE_ROCM)
|
||||||
void initCudnnBindings(PyObject* module);
|
void initCudnnBindings(PyObject* module);
|
||||||
#endif
|
#endif
|
||||||
|
|
@ -1978,6 +1980,7 @@ void initModule(PyObject* module) {
|
||||||
#if defined(USE_CUDNN) || defined(USE_ROCM)
|
#if defined(USE_CUDNN) || defined(USE_ROCM)
|
||||||
shared::initCudnnBindings(module);
|
shared::initCudnnBindings(module);
|
||||||
#endif
|
#endif
|
||||||
|
shared::initGdsBindings(module);
|
||||||
registerCudaDeviceProperties(module);
|
registerCudaDeviceProperties(module);
|
||||||
registerCudaPluggableAllocator(module);
|
registerCudaPluggableAllocator(module);
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -25,6 +25,7 @@ from torch import device as _device
|
||||||
from torch._utils import _dummy_type, _LazySeedTracker, classproperty
|
from torch._utils import _dummy_type, _LazySeedTracker, classproperty
|
||||||
from torch.types import Device
|
from torch.types import Device
|
||||||
|
|
||||||
|
from . import gds
|
||||||
from ._utils import _get_device_index
|
from ._utils import _get_device_index
|
||||||
from .graphs import (
|
from .graphs import (
|
||||||
CUDAGraph,
|
CUDAGraph,
|
||||||
|
|
|
||||||
129
torch/cuda/gds.py
Normal file
129
torch/cuda/gds.py
Normal file
|
|
@ -0,0 +1,129 @@
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
from typing import Callable, List, Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch.types import Storage
|
||||||
|
|
||||||
|
|
||||||
|
__all__: List[str] = []
|
||||||
|
|
||||||
|
|
||||||
|
def _dummy_fn(name: str) -> Callable:
|
||||||
|
def fn(*args, **kwargs): # type: ignore[no-untyped-def]
|
||||||
|
raise RuntimeError(f"torch._C.{name} is not supported on this platform")
|
||||||
|
|
||||||
|
return fn
|
||||||
|
|
||||||
|
|
||||||
|
if not hasattr(torch._C, "_gds_register_buffer"):
    # The bindings are expected to be present all together or not at all, so a
    # missing first binding means every other one must be missing as well.
    _gds_binding_names = (
        "_gds_register_buffer",
        "_gds_deregister_buffer",
        "_gds_register_handle",
        "_gds_deregister_handle",
        "_gds_load_storage",
        "_gds_save_storage",
    )
    for _gds_binding in _gds_binding_names[1:]:
        assert not hasattr(torch._C, _gds_binding)
    # Define functions
    for _gds_binding in _gds_binding_names:
        torch._C.__dict__[_gds_binding] = _dummy_fn(_gds_binding)
    # Do not leave the loop helpers behind in the module namespace.
    del _gds_binding_names, _gds_binding
|
||||||
|
|
||||||
|
|
||||||
|
def _gds_register_buffer(s: Storage) -> None:
|
||||||
|
"""Registers a buffer.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
s (Storage): Buffer to register.
|
||||||
|
"""
|
||||||
|
torch._C._gds_register_buffer(s)
|
||||||
|
|
||||||
|
|
||||||
|
def _gds_deregister_buffer(s: Storage) -> None:
|
||||||
|
"""Registers a buffer.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
s (Storage): Buffer to register.
|
||||||
|
"""
|
||||||
|
torch._C._gds_deregister_buffer(s)
|
||||||
|
|
||||||
|
|
||||||
|
class _GdsFile:
|
||||||
|
r"""Wrapper around cuFile.
|
||||||
|
|
||||||
|
cuFile is a file-like interface to the GPUDirect Storage (GDS) API.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
filename (str): Name of the file to open.
|
||||||
|
flags (int): Flags to pass to ``os.open`` when opening the file. ``os.O_DIRECT`` will
|
||||||
|
be added automatically.
|
||||||
|
|
||||||
|
.. _CUDA GPUDirect Storage Documentation:
|
||||||
|
https://docs.nvidia.com/gpudirect-storage/api-reference-guide/index.html#cufile-io-api
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, filename: str, flags: int):
|
||||||
|
if sys.platform == "win32":
|
||||||
|
raise RuntimeError("GdsFile is not supported on this platform.")
|
||||||
|
self.filename = filename
|
||||||
|
self.flags = flags
|
||||||
|
self.fd = os.open(filename, flags | os.O_DIRECT)
|
||||||
|
self.handle: Optional[int] = None
|
||||||
|
self.register_handle()
|
||||||
|
|
||||||
|
def __del__(self) -> None:
|
||||||
|
if self.handle is not None:
|
||||||
|
self.deregister_handle()
|
||||||
|
os.close(self.fd)
|
||||||
|
|
||||||
|
def register_handle(self) -> None:
|
||||||
|
"""Registers file descriptor to cuFile Driver.
|
||||||
|
|
||||||
|
This is a wrapper around ``cuFileHandleRegister``.
|
||||||
|
"""
|
||||||
|
assert (
|
||||||
|
self.handle is None
|
||||||
|
), "Cannot register a handle that is already registered."
|
||||||
|
self.handle = torch._C._gds_register_handle(self.fd)
|
||||||
|
|
||||||
|
def deregister_handle(self) -> None:
|
||||||
|
"""Deregisters file descriptor from cuFile Driver.
|
||||||
|
|
||||||
|
This is a wrapper around ``cuFileHandleDeregister``.
|
||||||
|
"""
|
||||||
|
assert (
|
||||||
|
self.handle is not None
|
||||||
|
), "Cannot deregister a handle that is not registered."
|
||||||
|
torch._C._gds_deregister_handle(self.handle)
|
||||||
|
self.handle = None
|
||||||
|
|
||||||
|
def load_storage(self, storage: Storage, offset: int = 0) -> None:
|
||||||
|
"""Loads data from the file into the storage.
|
||||||
|
|
||||||
|
This is a wrapper around ``cuFileRead``. ``storage.nbytes()`` of data
|
||||||
|
will be loaded from the file at ``offset`` into the storage.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
storage (Storage): Storage to load data into.
|
||||||
|
offset (int, optional): Offset into the file to start loading from. (Default: 0)
|
||||||
|
"""
|
||||||
|
assert (
|
||||||
|
self.handle is not None
|
||||||
|
), "Cannot load data from a file that is not registered."
|
||||||
|
torch._C._gds_load_storage(self.handle, storage, offset)
|
||||||
|
|
||||||
|
def save_storage(self, storage: Storage, offset: int = 0) -> None:
|
||||||
|
"""Saves data from the storage into the file.
|
||||||
|
|
||||||
|
This is a wrapper around ``cuFileWrite``. All bytes of the storage
|
||||||
|
will be written to the file at ``offset``.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
storage (Storage): Storage to save data from.
|
||||||
|
offset (int, optional): Offset into the file to start saving to. (Default: 0)
|
||||||
|
"""
|
||||||
|
assert (
|
||||||
|
self.handle is not None
|
||||||
|
), "Cannot save data to a file that is not registered."
|
||||||
|
torch._C._gds_save_storage(self.handle, storage, offset)
|
||||||
Loading…
Reference in New Issue
Block a user