[cuda][cupy] Improve cupy device placement when device is provided with explicit index (#158529)

resubmit https://github.com/pytorch/pytorch/pull/158320 , fixing a potential bug when device index is not specified explicitly.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158529
Approved by: https://github.com/ezyang
This commit is contained in:
Kaichao You 2025-08-15 00:27:42 +00:00 committed by PyTorch MergeBot
parent dc194a3096
commit dae7710bf2
5 changed files with 129 additions and 6 deletions

View File

@ -0,0 +1,110 @@
# Owner(s): ["oncall: distributed"]
# To run:
# python test/distributed/test_cupy_as_tensor.py
from dataclasses import dataclass
import torch
from torch.multiprocessing.reductions import reduce_tensor
from torch.testing._internal.common_distributed import MultiProcContinousTest
from torch.testing._internal.common_utils import (
requires_cuda_p2p_access,
run_tests,
skipIfRocm,
)
# So that tests are written in device-agnostic way
device_type = "cuda"
device_module = torch.get_device_module(device_type)
@dataclass
class CupyWrapper:
data_ptr: int
size_in_bytes: int
@property
def __cuda_array_interface__(self):
return {
"shape": (self.size_in_bytes,),
"typestr": "|u1",
"data": (self.data_ptr, False),
"version": 3,
}
def from_buffer(
data_ptr: int, size_in_bytes: int, device: str, dtype: torch.dtype
) -> torch.Tensor:
data = torch.as_tensor(CupyWrapper(data_ptr, size_in_bytes), device=device).view(
dtype
)
assert data.data_ptr() == data_ptr
return data
@requires_cuda_p2p_access()
class CupyAsTensorTest(MultiProcContinousTest):
@classmethod
def backend_str(cls):
return "gloo"
def _init_device(self) -> None:
# need to use vmm api to test it,
# see https://forums.developer.nvidia.com/t/inconsistent-behavior-of-cudapointergetattributes-between-cudamalloc-ipc-and-vmm-based-ipc/339025/5 # noqa: B950
torch.cuda.memory._set_allocator_settings("expandable_segments:True")
# init and pin the process to the device
device_module.set_device(self.device)
torch.empty(1, device=self.device)
@property
def device(self) -> torch.device:
return torch.device(device_type, self.rank)
@skipIfRocm
def test_cupy_as_tensor(self) -> None:
"""
Test that torch.as_tensor works for cupy array interface
with zero-copy when the pointer is p2p-shared across processes.
"""
self._init_device()
tensor: torch.Tensor
if self.rank == 1:
# it seems only error from rank non-zero will be caught by this test
tensor = torch.randn(2333, device=self.device)
tensor_meta = reduce_tensor(tensor)
torch.distributed.broadcast_object_list([tensor_meta], src=1)
else:
recv_list = [None]
torch.distributed.broadcast_object_list(recv_list, src=1)
tensor_meta = recv_list[0]
func, args = tensor_meta
args = list(args)
args[6] = self.rank
ipc_tensor = func(*args)
tensor = from_buffer(
ipc_tensor.data_ptr(),
ipc_tensor.numel() * ipc_tensor.element_size(),
self.device,
ipc_tensor.dtype,
)
torch.distributed.barrier()
if self.rank == 1:
tensor.fill_(1)
device_module.synchronize()
torch.distributed.barrier()
assert tensor.allclose(tensor, 1)
torch.distributed.barrier()
@classmethod
def tearDownClass(cls):
torch.cuda.memory._set_allocator_settings("expandable_segments:False")
super().tearDownClass()
if __name__ == "__main__":
run_tests()

View File

@ -1006,7 +1006,8 @@ If :attr:`data` is a NumPy array (an ndarray) with the same dtype and device the
tensor is constructed using :func:`torch.from_numpy`.
If :attr:`data` is a CuPy array, the returned tensor will be located on the same device as the CuPy array unless
specifically overwritten by :attr:`device` or a default device.
specifically overwritten by :attr:`device` or a default device. The device of the CuPy array is inferred from the
pointer of the array using `cudaPointerGetAttributes` unless :attr:`device` is provided with an explicit device index.
.. seealso::

View File

@ -304,7 +304,7 @@ Tensor internal_new_from_data(
TORCH_CHECK(
!pin_memory,
"Can't pin tensor constructed from __cuda_array_interface__");
auto tensor = tensor_from_cuda_array_interface(data);
auto tensor = tensor_from_cuda_array_interface(data, device_opt);
const auto& inferred_scalar_type =
type_inference ? tensor.scalar_type() : scalar_type;

View File

@ -27,7 +27,9 @@ bool is_numpy_int(PyObject* obj) {
bool is_numpy_scalar(PyObject* obj) {
throw std::runtime_error("PyTorch was compiled without NumPy support");
}
at::Tensor tensor_from_cuda_array_interface(PyObject* obj) {
at::Tensor tensor_from_cuda_array_interface(
PyObject* obj,
std::optional<c10::Device> device_opt) {
throw std::runtime_error("PyTorch was compiled without NumPy support");
}
@ -380,7 +382,9 @@ bool is_numpy_scalar(PyObject* obj) {
PyArray_IsScalar(obj, ComplexFloating));
}
at::Tensor tensor_from_cuda_array_interface(PyObject* obj) {
at::Tensor tensor_from_cuda_array_interface(
PyObject* obj,
std::optional<c10::Device> device_opt) {
if (!is_numpy_available()) {
throw std::runtime_error("Numpy is not available");
}
@ -489,7 +493,13 @@ at::Tensor tensor_from_cuda_array_interface(PyObject* obj) {
// ref:
// https://numba.readthedocs.io/en/stable/cuda/cuda_array_interface.html#cuda-array-interface-version-3
if (data_ptr != nullptr) {
return {};
if (device_opt.has_value() && device_opt->has_index()) {
// if device_opt is provided with explicit device index, use it
return device_opt;
} else {
// otherwise infer from cudaPointerGetAttributes later in from_blob
return std::nullopt;
}
} else {
const auto current_device = at::detail::getCUDAHooks().getCurrentDevice();
return Device(

View File

@ -22,7 +22,9 @@ TORCH_API bool is_numpy_bool(PyObject* obj);
TORCH_API bool is_numpy_scalar(PyObject* obj);
void warn_numpy_not_writeable();
at::Tensor tensor_from_cuda_array_interface(PyObject* obj);
at::Tensor tensor_from_cuda_array_interface(
PyObject* obj,
std::optional<c10::Device> device_opt = std::nullopt);
void validate_numpy_for_dlpack_deleter_bug();
bool is_numpy_dlpack_deleter_bugged();