Mirror of https://github.com/zebrajr/pytorch.git (synced 2025-12-06 12:20:52 +01:00)
[cuda][cupy] Improve cupy device placement when device is provided with explicit index (#158529)
Resubmit of https://github.com/pytorch/pytorch/pull/158320, fixing a potential bug when the device index is not specified explicitly. Pull Request resolved: https://github.com/pytorch/pytorch/pull/158529 Approved by: https://github.com/ezyang
This commit is contained in:
parent dc194a3096
commit dae7710bf2
test/distributed/test_cupy_as_tensor.py | 110 lines (new file)
@@ -0,0 +1,110 @@
# Owner(s): ["oncall: distributed"]

# To run:
# python test/distributed/test_cupy_as_tensor.py

from dataclasses import dataclass

import torch
from torch.multiprocessing.reductions import reduce_tensor
from torch.testing._internal.common_distributed import MultiProcContinousTest
from torch.testing._internal.common_utils import (
    requires_cuda_p2p_access,
    run_tests,
    skipIfRocm,
)


# So that tests are written in a device-agnostic way
device_type = "cuda"
device_module = torch.get_device_module(device_type)


@dataclass
class CupyWrapper:
    """Minimal stand-in for a CuPy array: exposes a raw device pointer
    through the CUDA Array Interface so torch.as_tensor can consume it."""

    data_ptr: int
    size_in_bytes: int

    @property
    def __cuda_array_interface__(self):
        return {
            "shape": (self.size_in_bytes,),
            "typestr": "|u1",  # raw bytes
            "data": (self.data_ptr, False),
            "version": 3,
        }


def from_buffer(
    data_ptr: int, size_in_bytes: int, device: str, dtype: torch.dtype
) -> torch.Tensor:
    data = torch.as_tensor(CupyWrapper(data_ptr, size_in_bytes), device=device).view(
        dtype
    )
    # zero-copy: the resulting tensor must alias the original pointer
    assert data.data_ptr() == data_ptr
    return data


@requires_cuda_p2p_access()
class CupyAsTensorTest(MultiProcContinousTest):
    @classmethod
    def backend_str(cls):
        return "gloo"

    def _init_device(self) -> None:
        # need to use the VMM API to test it, see
        # https://forums.developer.nvidia.com/t/inconsistent-behavior-of-cudapointergetattributes-between-cudamalloc-ipc-and-vmm-based-ipc/339025/5 # noqa: B950
        torch.cuda.memory._set_allocator_settings("expandable_segments:True")
        # init and pin the process to the device
        device_module.set_device(self.device)
        torch.empty(1, device=self.device)

    @property
    def device(self) -> torch.device:
        return torch.device(device_type, self.rank)

    @skipIfRocm
    def test_cupy_as_tensor(self) -> None:
        """
        Test that torch.as_tensor works for the CUDA array interface
        with zero-copy when the pointer is p2p-shared across processes.
        """
        self._init_device()

        tensor: torch.Tensor
        if self.rank == 1:
            # it seems only errors from non-zero ranks are caught by this test
            tensor = torch.randn(2333, device=self.device)
            tensor_meta = reduce_tensor(tensor)
            torch.distributed.broadcast_object_list([tensor_meta], src=1)
        else:
            recv_list = [None]
            torch.distributed.broadcast_object_list(recv_list, src=1)
            tensor_meta = recv_list[0]
            func, args = tensor_meta
            args = list(args)
            # rewrite the device index in the rebuild args to this rank
            args[6] = self.rank
            ipc_tensor = func(*args)
            tensor = from_buffer(
                ipc_tensor.data_ptr(),
                ipc_tensor.numel() * ipc_tensor.element_size(),
                self.device,
                ipc_tensor.dtype,
            )

        torch.distributed.barrier()
        if self.rank == 1:
            tensor.fill_(1)
            device_module.synchronize()
        torch.distributed.barrier()
        # every rank should now observe the value written by rank 1
        assert tensor.allclose(torch.ones_like(tensor))
        torch.distributed.barrier()

    @classmethod
    def tearDownClass(cls):
        torch.cuda.memory._set_allocator_settings("expandable_segments:False")
        super().tearDownClass()


if __name__ == "__main__":
    run_tests()
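For orientation, here is a minimal single-process sketch of the zero-copy pattern the test exercises. It is not part of the commit; it assumes one CUDA device and reuses the CupyWrapper/from_buffer helpers defined above:

# Single-process smoke test of the zero-copy path (assumes a CUDA device).
src = torch.arange(8, dtype=torch.float32, device="cuda:0")
view = from_buffer(
    src.data_ptr(),
    src.numel() * src.element_size(),
    "cuda:0",
    src.dtype,
)
src.fill_(3.0)
torch.cuda.synchronize()
# view aliases src's memory, so it observes the write without a copy
assert view.allclose(torch.full_like(view, 3.0))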
torch/_torch_docs.py
@@ -1006,7 +1006,8 @@ If :attr:`data` is a NumPy array (an ndarray) with the same dtype and device the
 tensor is constructed using :func:`torch.from_numpy`.

 If :attr:`data` is a CuPy array, the returned tensor will be located on the same device as the CuPy array unless
-specifically overwritten by :attr:`device` or a default device.
+specifically overwritten by :attr:`device` or a default device. The device of the CuPy array is inferred from the
+pointer of the array using `cudaPointerGetAttributes` unless :attr:`device` is provided with an explicit device index.

 .. seealso::
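A hedged illustration of the documented behavior (not part of the commit; assumes CuPy is installed and at least two CUDA devices are visible):

import cupy
import torch

# Allocate a CuPy array on device 1.
with cupy.cuda.Device(1):
    arr = cupy.zeros(16, dtype=cupy.float32)

# No device argument: placement is inferred from the array's pointer
# via cudaPointerGetAttributes, so the tensor lands on cuda:1.
assert torch.as_tensor(arr).device == torch.device("cuda", 1)

# A bare "cuda" (no index) also defers to pointer-based inference;
# this is the case the resubmit (#158529) fixes.
assert torch.as_tensor(arr, device="cuda").device == torch.device("cuda", 1)

# An explicit index skips pointer-based inference and is used directly.
assert torch.as_tensor(arr, device="cuda:1").device == torch.device("cuda", 1)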
torch/csrc/utils/tensor_new.cpp
@@ -304,7 +304,7 @@ Tensor internal_new_from_data(
       TORCH_CHECK(
           !pin_memory,
           "Can't pin tensor constructed from __cuda_array_interface__");
-      auto tensor = tensor_from_cuda_array_interface(data);
+      auto tensor = tensor_from_cuda_array_interface(data, device_opt);
       const auto& inferred_scalar_type =
           type_inference ? tensor.scalar_type() : scalar_type;
torch/csrc/utils/tensor_numpy.cpp
@@ -27,7 +27,9 @@ bool is_numpy_int(PyObject* obj) {
 bool is_numpy_scalar(PyObject* obj) {
   throw std::runtime_error("PyTorch was compiled without NumPy support");
 }
-at::Tensor tensor_from_cuda_array_interface(PyObject* obj) {
+at::Tensor tensor_from_cuda_array_interface(
+    PyObject* obj,
+    std::optional<c10::Device> device_opt) {
   throw std::runtime_error("PyTorch was compiled without NumPy support");
 }
@@ -380,7 +382,9 @@ bool is_numpy_scalar(PyObject* obj) {
       PyArray_IsScalar(obj, ComplexFloating));
 }

-at::Tensor tensor_from_cuda_array_interface(PyObject* obj) {
+at::Tensor tensor_from_cuda_array_interface(
+    PyObject* obj,
+    std::optional<c10::Device> device_opt) {
   if (!is_numpy_available()) {
     throw std::runtime_error("Numpy is not available");
   }
@@ -489,7 +493,13 @@ at::Tensor tensor_from_cuda_array_interface(PyObject* obj) {
   // ref:
   // https://numba.readthedocs.io/en/stable/cuda/cuda_array_interface.html#cuda-array-interface-version-3
   if (data_ptr != nullptr) {
-    return {};
+    if (device_opt.has_value() && device_opt->has_index()) {
+      // if device_opt is provided with explicit device index, use it
+      return device_opt;
+    } else {
+      // otherwise infer from cudaPointerGetAttributes later in from_blob
+      return std::nullopt;
+    }
   } else {
     const auto current_device = at::detail::getCUDAHooks().getCurrentDevice();
     return Device(
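In Python terms, the resolution rule added above reads roughly as follows. This is a sketch with hypothetical names, not the actual implementation; the real logic is the C++ branch inside tensor_from_cuda_array_interface:

from typing import Optional

import torch


def resolve_device(
    data_ptr: int, device_opt: Optional[torch.device]
) -> Optional[torch.device]:
    if data_ptr != 0:
        if device_opt is not None and device_opt.index is not None:
            # explicit device index provided by the caller: use it as-is
            return device_opt
        # no explicit index: return None so the device is inferred from
        # the pointer (cudaPointerGetAttributes) later, in from_blob
        return None
    # null pointer (e.g. a zero-sized array): fall back to the current device
    return torch.device("cuda", torch.cuda.current_device())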
torch/csrc/utils/tensor_numpy.h
@@ -22,7 +22,9 @@ TORCH_API bool is_numpy_bool(PyObject* obj);
 TORCH_API bool is_numpy_scalar(PyObject* obj);

 void warn_numpy_not_writeable();
-at::Tensor tensor_from_cuda_array_interface(PyObject* obj);
+at::Tensor tensor_from_cuda_array_interface(
+    PyObject* obj,
+    std::optional<c10::Device> device_opt = std::nullopt);

 void validate_numpy_for_dlpack_deleter_bug();
 bool is_numpy_dlpack_deleter_bugged();