diff --git a/CMakeLists.txt b/CMakeLists.txt index aaa5c02e10d..1083d934b80 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -251,6 +251,15 @@ cmake_dependent_option(USE_CUDNN "Use cuDNN" ON "USE_CUDA" OFF) cmake_dependent_option(USE_STATIC_CUDNN "Use cuDNN static libraries" OFF "USE_CUDNN" OFF) cmake_dependent_option(USE_CUSPARSELT "Use cuSPARSELt" ON "USE_CUDA" OFF) +# Binary builds will fail for cufile due to https://github.com/pytorch/builder/issues/1924 +# Using TH_BINARY_BUILD to check whether is binary build. +# USE_ROCM is guarded against in Dependencies.cmake because USE_ROCM is not properly defined here +if(DEFINED ENV{TH_BINARY_BUILD}) + cmake_dependent_option(USE_CUFILE "Use cuFile" OFF + "USE_CUDA AND NOT $ENV{TH_BINARY_BUILD} AND NOT WIN32" OFF) +else() + cmake_dependent_option(USE_CUFILE "Use cuFile" OFF "USE_CUDA AND NOT WIN32" OFF) +endif() option(USE_FBGEMM "Use FBGEMM (quantized 8-bit server operators)" ON) option(USE_KINETO "Use Kineto profiling library" ON) option(USE_CUPTI_SO "Use CUPTI as a shared library" ON) diff --git a/build_variables.bzl b/build_variables.bzl index b3b58169b79..7a5362cc5ad 100644 --- a/build_variables.bzl +++ b/build_variables.bzl @@ -773,6 +773,7 @@ libtorch_python_cuda_core_sources = [ "torch/csrc/cuda/shared/cudart.cpp", "torch/csrc/cuda/shared/nvtx.cpp", "torch/csrc/cuda/utils.cpp", + "torch/csrc/cuda/GdsFile.cpp", ] libtorch_python_cuda_sources = libtorch_python_cuda_core_sources + [ diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt index 5115944b389..746a1da3fed 100644 --- a/caffe2/CMakeLists.txt +++ b/caffe2/CMakeLists.txt @@ -928,6 +928,10 @@ elseif(USE_CUDA) torch_compile_options(torch_cuda) # see cmake/public/utils.cmake target_compile_definitions(torch_cuda PRIVATE USE_CUDA) + if(USE_CUFILE) + target_link_libraries(torch_cuda PRIVATE torch::cufile) + target_compile_definitions(torch_cuda PRIVATE USE_CUFILE) + endif() if(USE_CUSPARSELT) target_link_libraries(torch_cuda PRIVATE torch::cusparselt) target_compile_definitions(torch_cuda PRIVATE USE_CUSPARSELT) diff --git a/cmake/Dependencies.cmake b/cmake/Dependencies.cmake index 091317acb27..ef33a316534 100644 --- a/cmake/Dependencies.cmake +++ b/cmake/Dependencies.cmake @@ -39,6 +39,7 @@ if(USE_CUDA) set(CAFFE2_USE_CUDA ${USE_CUDA}) set(CAFFE2_USE_CUDNN ${USE_CUDNN}) set(CAFFE2_USE_CUSPARSELT ${USE_CUSPARSELT}) + set(CAFFE2_USE_CUFILE ${USE_CUFILE}) set(CAFFE2_USE_NVRTC ${USE_NVRTC}) include(${CMAKE_CURRENT_LIST_DIR}/public/cuda.cmake) if(CAFFE2_USE_CUDA) @@ -60,6 +61,9 @@ if(USE_CUDA) else() caffe2_update_option(USE_CUSPARSELT OFF) endif() + if(CAFFE2_USE_CUFILE) + list(APPEND Caffe2_CUDA_DEPENDENCY_LIBS torch::cufile) + endif() find_program(SCCACHE_EXECUTABLE sccache) if(SCCACHE_EXECUTABLE) # Using RSP/--options-file renders output noncacheable by sccache @@ -79,6 +83,7 @@ if(USE_CUDA) set(CAFFE2_USE_CUDA OFF) set(CAFFE2_USE_CUDNN OFF) set(CAFFE2_USE_CUSPARSELT OFF) + set(CAFFE2_USE_CUFILE OFF) set(CAFFE2_USE_NVRTC OFF) endif() endif() @@ -1039,7 +1044,6 @@ if(USE_ROCM) caffe2_update_option(USE_SYSTEM_NCCL ON) endif() - list(APPEND HIP_CXX_FLAGS -fPIC) list(APPEND HIP_CXX_FLAGS -D__HIP_PLATFORM_AMD__=1) list(APPEND HIP_CXX_FLAGS -DCUDA_HAS_FP16=1) diff --git a/cmake/Modules/FindCUDAToolkit.cmake b/cmake/Modules/FindCUDAToolkit.cmake index 7c8a79c5493..ec9ae530aa6 100644 --- a/cmake/Modules/FindCUDAToolkit.cmake +++ b/cmake/Modules/FindCUDAToolkit.cmake @@ -978,6 +978,14 @@ if(CUDAToolkit_FOUND) _CUDAToolkit_find_and_add_import_lib(cublas_static DEPS culibos) endif() + if(CUDAToolkit_VERSION VERSION_GREATER_EQUAL 11.4) + _CUDAToolkit_find_and_add_import_lib(cuFile ALT cufile DEPS culibos) + _CUDAToolkit_find_and_add_import_lib(cuFile_static ALT cufile_static DEPS culibos) + + _CUDAToolkit_find_and_add_import_lib(cuFile_rdma ALT cufile_rdma DEPS cuFile culibos) + _CUDAToolkit_find_and_add_import_lib(cuFile_rdma_static ALT cufile_rdma_static DEPS cuFile_static culibos) + endif() + # cuFFTW depends on cuFFT _CUDAToolkit_find_and_add_import_lib(cufftw DEPS cufft) _CUDAToolkit_find_and_add_import_lib(cufftw_static DEPS cufft_static) diff --git a/cmake/Summary.cmake b/cmake/Summary.cmake index 95aac142f55..f916dd3cc96 100644 --- a/cmake/Summary.cmake +++ b/cmake/Summary.cmake @@ -74,6 +74,7 @@ function(caffe2_print_configuration_summary) message(STATUS " CUDA static link : ${CAFFE2_STATIC_LINK_CUDA}") message(STATUS " USE_CUDNN : ${USE_CUDNN}") message(STATUS " USE_CUSPARSELT : ${USE_CUSPARSELT}") + message(STATUS " USE_CUFILE : ${USE_CUFILE}") message(STATUS " CUDA version : ${CUDA_VERSION}") message(STATUS " USE_FLASH_ATTENTION : ${USE_FLASH_ATTENTION}") message(STATUS " USE_MEM_EFF_ATTENTION : ${USE_MEM_EFF_ATTENTION}") @@ -83,6 +84,9 @@ function(caffe2_print_configuration_summary) if(${USE_CUSPARSELT}) message(STATUS " cuSPARSELt version : ${CUSPARSELT_VERSION}") endif() + if(${USE_CUFILE}) + message(STATUS " cufile library : ${CUDA_cuFile_LIBRARY}") + endif() message(STATUS " CUDA root directory : ${CUDA_TOOLKIT_ROOT_DIR}") message(STATUS " CUDA library : ${CUDA_cuda_driver_LIBRARY}") message(STATUS " cudart library : ${CUDA_cudart_LIBRARY}") diff --git a/cmake/public/cuda.cmake b/cmake/public/cuda.cmake index 99378d84bd6..fc5cee457df 100644 --- a/cmake/public/cuda.cmake +++ b/cmake/public/cuda.cmake @@ -244,6 +244,22 @@ else() message(STATUS "USE_CUSPARSELT is set to 0. Compiling without cuSPARSELt support") endif() +# cufile +if(CAFFE2_USE_CUFILE) + add_library(torch::cufile INTERFACE IMPORTED) + if(CAFFE2_STATIC_LINK_CUDA AND NOT WIN32) + set_property( + TARGET torch::cufile PROPERTY INTERFACE_LINK_LIBRARIES + CUDA::cuFile_static) + else() + set_property( + TARGET torch::cufile PROPERTY INTERFACE_LINK_LIBRARIES + CUDA::cuFile) + endif() +else() + message(STATUS "USE_CUFILE is set to 0. Compiling without cuFile support") +endif() + # curand add_library(caffe2::curand INTERFACE IMPORTED) if(CAFFE2_STATIC_LINK_CUDA AND NOT WIN32) diff --git a/docs/source/cuda.rst b/docs/source/cuda.rst index 13fedd29327..328c3cecead 100644 --- a/docs/source/cuda.rst +++ b/docs/source/cuda.rst @@ -183,6 +183,7 @@ See the :doc:`documentation ` for information on how to use it. .. for tracking purposes .. py:module:: torch.cuda.comm .. py:module:: torch.cuda.error +.. py:module:: torch.cuda.gds .. py:module:: torch.cuda.graphs .. py:module:: torch.cuda.jiterator .. py:module:: torch.cuda.memory diff --git a/setup.py b/setup.py index cad26e12c58..86c7fde2227 100644 --- a/setup.py +++ b/setup.py @@ -38,6 +38,9 @@ # USE_CUSPARSELT=0 # disables the cuSPARSELt build # +# USE_CUFILE=0 +# disables the cuFile build +# # USE_FBGEMM=0 # disables the FBGEMM build # diff --git a/test/test_cuda.py b/test/test_cuda.py index e61930c0334..e5e25a678e2 100644 --- a/test/test_cuda.py +++ b/test/test_cuda.py @@ -17,6 +17,8 @@ from copy import deepcopy from itertools import product from random import randint +import psutil + import torch import torch.cuda from torch import inf, nan @@ -62,6 +64,7 @@ from torch.testing._internal.common_utils import ( skipIfRocm, slowTest, subtest, + TemporaryFileName, TEST_CUDA, TEST_CUDA_GRAPH, TEST_NUMPY, @@ -3998,6 +4001,16 @@ print(f"{{r1}}, {{r2}}") x = torch.cuda.device_count() self.assertEqual(f"{x}, 1", r) + @unittest.skip("Disabling as USE_CUFILE=0 by default in builds") + def test_gds_fails_in_ci(self): + if IS_WINDOWS or TEST_WITH_ROCM: + error_msg = "is not supported on this platform" + else: + error_msg = "cuFileHandleRegister failed" + with TemporaryFileName() as f: + with self.assertRaisesRegex(RuntimeError, error_msg): + file = torch.cuda.gds._GdsFile(f, os.O_CREAT | os.O_RDWR) + @torch.testing._internal.common_utils.markDynamoStrictTest class TestCudaMallocAsync(TestCase): @@ -5242,6 +5255,40 @@ class TestCudaOptims(TestCase): self.assertEqual(scaler._growth_tracker, growth_tracker) +class TestGDS(TestCase): + def _get_tmp_dir_fs_type(self): + my_path = os.path.realpath("/tmp") + root_type = "" + for part in psutil.disk_partitions(): + if part.mountpoint == "/": + root_type = part.fstype + continue + if part.mountpoint == my_path: + return part.fstype + return root_type + + @unittest.skip("Disabling as USE_CUFILE=0 by default in builds") + def test_gds_read_write_tensors(self): + if self._get_tmp_dir_fs_type() not in ("ext4", "xfs"): + self.skipTest("GPUDirect Storage requires ext4/xfs for local filesystem") + src1 = torch.randn(1024, device="cuda") + src2 = torch.randn(2, 1024, device="cuda") + torch.cuda.gds._gds_register_buffer(src1.untyped_storage()) + torch.cuda.gds._gds_register_buffer(src2.untyped_storage()) + dest1 = torch.empty(1024, device="cuda") + dest2 = torch.empty(2, 1024, device="cuda") + with TemporaryFileName() as f: + file = torch.cuda.gds._GdsFile(f, os.O_CREAT | os.O_RDWR) + file.save_storage(src1.untyped_storage(), offset=0) + file.save_storage(src2.untyped_storage(), offset=src1.nbytes) + file.load_storage(dest1.untyped_storage(), offset=0) + file.load_storage(dest2.untyped_storage(), offset=src1.nbytes) + self.assertEqual(src1, dest1) + self.assertEqual(src2, dest2) + torch.cuda.gds._gds_deregister_buffer(src1.untyped_storage()) + torch.cuda.gds._gds_deregister_buffer(src2.untyped_storage()) + + instantiate_parametrized_tests(TestCuda) instantiate_parametrized_tests(TestCudaMallocAsync) instantiate_device_type_tests(TestCudaOptims, globals()) diff --git a/third_party/cuda.BUILD b/third_party/cuda.BUILD index a948415f913..4767231b558 100644 --- a/third_party/cuda.BUILD +++ b/third_party/cuda.BUILD @@ -60,6 +60,12 @@ cc_library( visibility = ["//visibility:public"], ) +cc_library( + name = "cufile", + srcs = ["targets/x86_64-linux/lib/libcufile.so"], + visibility = ["//visibility:public"], +) + cc_library( name = "nvrtc", srcs = [ diff --git a/torch/CMakeLists.txt b/torch/CMakeLists.txt index 8c71c0b95f0..cc376a4c08e 100644 --- a/torch/CMakeLists.txt +++ b/torch/CMakeLists.txt @@ -312,6 +312,10 @@ if(USE_NUMPY) target_compile_definitions(torch_python PRIVATE USE_NUMPY) endif() +if(USE_CUFILE AND NOT USE_ROCM) + target_compile_definitions(torch_python PRIVATE USE_CUFILE) +endif() + if(HAVE_SOVERSION) set_target_properties(torch_python PROPERTIES VERSION ${TORCH_VERSION} SOVERSION ${TORCH_SOVERSION}) diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in index c78424fb4f5..471ceea881c 100644 --- a/torch/_C/__init__.pyi.in +++ b/torch/_C/__init__.pyi.in @@ -1981,6 +1981,14 @@ def _can_use_cudnn_attention(params: _SDPAParams, debug: _bool) -> _bool: ... def _can_use_flash_attention(params: _SDPAParams, debug: _bool) -> _bool: ... def _can_use_mem_efficient_attention(params: _SDPAParams, debug: _bool) -> _bool: ... +# Defined in torch/csrc/cuda/GdsFile.cpp +def _gds_register_buffer(t: Storage) -> None: ... +def _gds_deregister_buffer(t: Storage) -> None: ... +def _gds_register_handle(fd: _int) -> _int: ... +def _gds_deregister_handle(handle: _int) -> None: ... +def _gds_load_storage(handle: _int, s: Storage, offset: _int) -> None: ... +def _gds_save_storage(handle: _int, s: Storage, offset: _int) -> None: ... + # Defined in torch/csrc/cuda/python_comm.cpp def _broadcast(tensor: Tensor, devices: List[_int]) -> List[Tensor]: ... def _broadcast_out(tensor: Tensor, out_tensors: List[Tensor]) -> List[Tensor]: ... diff --git a/torch/__init__.py b/torch/__init__.py index bb1bf44533c..294849d4c56 100644 --- a/torch/__init__.py +++ b/torch/__init__.py @@ -308,6 +308,7 @@ def _load_global_deps() -> None: "cuda_runtime": "libcudart.so.*[0-9]", "cuda_cupti": "libcupti.so.*[0-9]", "cufft": "libcufft.so.*[0-9]", + "cufile": "libcufile.so.*[0-9]", "curand": "libcurand.so.*[0-9]", "nvjitlink": "libnvJitLink.so.*[0-9]", "cusparse": "libcusparse.so.*[0-9]", diff --git a/torch/csrc/cuda/GdsFile.cpp b/torch/csrc/cuda/GdsFile.cpp new file mode 100644 index 00000000000..b95b86b3374 --- /dev/null +++ b/torch/csrc/cuda/GdsFile.cpp @@ -0,0 +1,134 @@ +#include +#include + +#if defined(USE_CUFILE) +#include + +#include +#include + +namespace { +// To get error message for cuFileRead/Write APIs that return ssize_t (-1 for +// filesystem error and a negative CUfileOpError enum value otherwise). +template < + class T, + typename std::enable_if::value, std::nullptr_t>::type = + nullptr> +std::string cuGDSFileGetErrorString(T status) { + status = std::abs(status); + return IS_CUFILE_ERR(status) ? std::string(CUFILE_ERRSTR(status)) + : std::string(std::strerror(errno)); +} + +// To get error message for Buf/Handle registeration APIs that return +// CUfileError_t +template < + class T, + typename std::enable_if::value, std::nullptr_t>::type = + nullptr> +std::string cuGDSFileGetErrorString(T status) { + std::string errStr = cuGDSFileGetErrorString(static_cast(status.err)); + if (IS_CUDA_ERR(status)) + errStr.append(".").append( + cudaGetErrorString(static_cast(status.cu_err))); + return errStr; +} +} // namespace + +void gds_load_storage( + int64_t handle, + const at::Storage& storage, + off_t offset) { + // NOLINTNEXTLINE(performance-no-int-to-ptr) + CUfileHandle_t cf_handle = reinterpret_cast(handle); + c10::cuda::CUDAGuard gpuGuard(storage.device()); + + void* dataPtr = storage.mutable_data(); + const size_t nbytes = storage.nbytes(); + + // Read the binary file + ssize_t ret = cuFileRead(cf_handle, (void*)dataPtr, nbytes, offset, 0); + TORCH_CHECK(ret >= 0, "cuFileRead failed: ", cuGDSFileGetErrorString(ret)); +} + +void gds_save_storage( + int64_t handle, + const at::Storage& storage, + off_t offset) { + // NOLINTNEXTLINE(performance-no-int-to-ptr) + CUfileHandle_t cf_handle = reinterpret_cast(handle); + c10::cuda::CUDAGuard gpuGuard(storage.device()); + + void* dataPtr = storage.mutable_data(); + const size_t nbytes = storage.nbytes(); + + // Write device memory contents to the file + ssize_t ret = cuFileWrite(cf_handle, dataPtr, nbytes, offset, 0); + TORCH_CHECK(ret >= 0, "cuFileWrite failed: ", cuGDSFileGetErrorString(ret)); +} + +void gds_register_buffer(const at::Storage& storage) { + void* dataPtr = storage.mutable_data(); + const size_t nbytes = storage.nbytes(); + + CUfileError_t status = cuFileBufRegister(dataPtr, nbytes, 0); + TORCH_CHECK( + status.err == CU_FILE_SUCCESS, + "cuFileBufRegister failed: ", + cuGDSFileGetErrorString(status)); + return; +} + +void gds_deregister_buffer(const at::Storage& storage) { + void* dataPtr = storage.mutable_data(); + CUfileError_t status = cuFileBufDeregister(dataPtr); + TORCH_CHECK( + status.err == CU_FILE_SUCCESS, + "cuFileBufDeregister failed: ", + cuGDSFileGetErrorString(status)); + return; +} + +int64_t gds_register_handle(int fd) { + CUfileDescr_t cf_descr; + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + CUfileHandle_t cf_handle; + memset((void*)&cf_descr, 0, sizeof(CUfileDescr_t)); + cf_descr.handle.fd = fd; + cf_descr.type = CU_FILE_HANDLE_TYPE_OPAQUE_FD; + CUfileError_t status = cuFileHandleRegister(&cf_handle, &cf_descr); + if (status.err != CU_FILE_SUCCESS) { + TORCH_CHECK( + false, + "cuFileHandleRegister failed: ", + cuGDSFileGetErrorString(status)); + } + + // Returning cuFileHandle_t as int64_t + return reinterpret_cast(cf_handle); +} + +void gds_deregister_handle(int64_t handle) { + // NOLINTNEXTLINE(performance-no-int-to-ptr) + CUfileHandle_t cf_handle = reinterpret_cast(handle); + cuFileHandleDeregister(cf_handle); +} + +#endif + +namespace torch::cuda::shared { + +void initGdsBindings(PyObject* module) { + auto m = py::handle(module).cast(); + +#if defined(USE_CUFILE) + m.def("_gds_register_handle", &gds_register_handle); + m.def("_gds_deregister_handle", &gds_deregister_handle); + m.def("_gds_register_buffer", &gds_register_buffer); + m.def("_gds_deregister_buffer", &gds_deregister_buffer); + m.def("_gds_load_storage", &gds_load_storage); + m.def("_gds_save_storage", &gds_save_storage); +#endif +} + +} // namespace torch::cuda::shared diff --git a/torch/csrc/cuda/GdsFile.h b/torch/csrc/cuda/GdsFile.h new file mode 100644 index 00000000000..0edf927393d --- /dev/null +++ b/torch/csrc/cuda/GdsFile.h @@ -0,0 +1,7 @@ +#ifndef THCP_GDSFILE_INC +#define THCP_GDSFILE_INC + +#include + +void initGdsBindings(PyObject* module); +#endif // THCP_GDSFILE_INC diff --git a/torch/csrc/cuda/Module.cpp b/torch/csrc/cuda/Module.cpp index 8051b320f2e..8249881c1cd 100644 --- a/torch/csrc/cuda/Module.cpp +++ b/torch/csrc/cuda/Module.cpp @@ -35,6 +35,7 @@ #include #include #include +#include #include #include #include @@ -1963,6 +1964,7 @@ namespace shared { void initCudartBindings(PyObject* module); void initNvtxBindings(PyObject* module); +void initGdsBindings(PyObject* module); #if defined(USE_CUDNN) || defined(USE_ROCM) void initCudnnBindings(PyObject* module); #endif @@ -1978,6 +1980,7 @@ void initModule(PyObject* module) { #if defined(USE_CUDNN) || defined(USE_ROCM) shared::initCudnnBindings(module); #endif + shared::initGdsBindings(module); registerCudaDeviceProperties(module); registerCudaPluggableAllocator(module); } diff --git a/torch/cuda/__init__.py b/torch/cuda/__init__.py index 4526c719f33..e536be1df14 100644 --- a/torch/cuda/__init__.py +++ b/torch/cuda/__init__.py @@ -25,6 +25,7 @@ from torch import device as _device from torch._utils import _dummy_type, _LazySeedTracker, classproperty from torch.types import Device +from . import gds from ._utils import _get_device_index from .graphs import ( CUDAGraph, diff --git a/torch/cuda/gds.py b/torch/cuda/gds.py new file mode 100644 index 00000000000..7cd5b882410 --- /dev/null +++ b/torch/cuda/gds.py @@ -0,0 +1,129 @@ +import os +import sys +from typing import Callable, List, Optional + +import torch +from torch.types import Storage + + +__all__: List[str] = [] + + +def _dummy_fn(name: str) -> Callable: + def fn(*args, **kwargs): # type: ignore[no-untyped-def] + raise RuntimeError(f"torch._C.{name} is not supported on this platform") + + return fn + + +if not hasattr(torch._C, "_gds_register_buffer"): + assert not hasattr(torch._C, "_gds_deregister_buffer") + assert not hasattr(torch._C, "_gds_register_handle") + assert not hasattr(torch._C, "_gds_deregister_handle") + assert not hasattr(torch._C, "_gds_load_storage") + assert not hasattr(torch._C, "_gds_save_storage") + # Define functions + torch._C.__dict__["_gds_register_buffer"] = _dummy_fn("_gds_register_buffer") + torch._C.__dict__["_gds_deregister_buffer"] = _dummy_fn("_gds_deregister_buffer") + torch._C.__dict__["_gds_register_handle"] = _dummy_fn("_gds_register_handle") + torch._C.__dict__["_gds_deregister_handle"] = _dummy_fn("_gds_deregister_handle") + torch._C.__dict__["_gds_load_storage"] = _dummy_fn("_gds_load_storage") + torch._C.__dict__["_gds_save_storage"] = _dummy_fn("_gds_save_storage") + + +def _gds_register_buffer(s: Storage) -> None: + """Registers a buffer. + + Args: + s (Storage): Buffer to register. + """ + torch._C._gds_register_buffer(s) + + +def _gds_deregister_buffer(s: Storage) -> None: + """Registers a buffer. + + Args: + s (Storage): Buffer to register. + """ + torch._C._gds_deregister_buffer(s) + + +class _GdsFile: + r"""Wrapper around cuFile. + + cuFile is a file-like interface to the GPUDirect Storage (GDS) API. + + Args: + filename (str): Name of the file to open. + flags (int): Flags to pass to ``os.open`` when opening the file. ``os.O_DIRECT`` will + be added automatically. + + .. _CUDA GPUDirect Storage Documentation: + https://docs.nvidia.com/gpudirect-storage/api-reference-guide/index.html#cufile-io-api + """ + + def __init__(self, filename: str, flags: int): + if sys.platform == "win32": + raise RuntimeError("GdsFile is not supported on this platform.") + self.filename = filename + self.flags = flags + self.fd = os.open(filename, flags | os.O_DIRECT) + self.handle: Optional[int] = None + self.register_handle() + + def __del__(self) -> None: + if self.handle is not None: + self.deregister_handle() + os.close(self.fd) + + def register_handle(self) -> None: + """Registers file descriptor to cuFile Driver. + + This is a wrapper around ``cuFileHandleRegister``. + """ + assert ( + self.handle is None + ), "Cannot register a handle that is already registered." + self.handle = torch._C._gds_register_handle(self.fd) + + def deregister_handle(self) -> None: + """Deregisters file descriptor from cuFile Driver. + + This is a wrapper around ``cuFileHandleDeregister``. + """ + assert ( + self.handle is not None + ), "Cannot deregister a handle that is not registered." + torch._C._gds_deregister_handle(self.handle) + self.handle = None + + def load_storage(self, storage: Storage, offset: int = 0) -> None: + """Loads data from the file into the storage. + + This is a wrapper around ``cuFileRead``. ``storage.nbytes()`` of data + will be loaded from the file at ``offset`` into the storage. + + Args: + storage (Storage): Storage to load data into. + offset (int, optional): Offset into the file to start loading from. (Default: 0) + """ + assert ( + self.handle is not None + ), "Cannot load data from a file that is not registered." + torch._C._gds_load_storage(self.handle, storage, offset) + + def save_storage(self, storage: Storage, offset: int = 0) -> None: + """Saves data from the storage into the file. + + This is a wrapper around ``cuFileWrite``. All bytes of the storage + will be written to the file at ``offset``. + + Args: + storage (Storage): Storage to save data from. + offset (int, optional): Offset into the file to start saving to. (Default: 0) + """ + assert ( + self.handle is not None + ), "Cannot save data to a file that is not registered." + torch._C._gds_save_storage(self.handle, storage, offset)