pin_memory support for NT (#110404)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/110404
Approved by: https://github.com/cpuhrsch, https://github.com/albanD
ghstack dependencies: #110292
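For context, here is a minimal usage sketch of what this change enables. It is not part of the diff; it assumes a CUDA-capable build, and the final non-blocking transfer assumes device copies are supported for this nested layout.

import torch

# A nested tensor built from variable-length components; its data lives in one flat CPU buffer.
nt = torch.nested.nested_tensor([torch.randn(2, 8), torch.randn(5, 8)])

if torch.cuda.is_available():
    pinned = nt.pin_memory()   # newly supported for nested tensors
    assert pinned.is_pinned()  # is_pinned is also routed correctly for nested tensors
    # Page-locked host memory allows an asynchronous host-to-device copy.
    nt_cuda = pinned.to("cuda", non_blocking=True)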
This commit is contained in: parent cc1de49340, commit 3597325bc7
@@ -4405,7 +4405,7 @@
 - func: is_pinned(Tensor self, Device? device=None) -> bool
   variants: method
   dispatch:
-    CUDA: is_pinned_cuda
+    NestedTensorCUDA, CUDA: is_pinned_cuda
     MPS: is_pinned_mps
     CompositeExplicitAutograd: is_pinned_default

@@ -4419,6 +4419,7 @@
   dispatch:
     CUDA: _pin_memory_cuda
     MPS: _pin_memory_mps
+    NestedTensorCUDA, NestedTensorCPU: _pin_memory_nested
   autogen: _pin_memory.out

 - func: pinverse(Tensor self, float rcond=1e-15) -> Tensor
@@ -132,5 +132,15 @@ Tensor cos_nested(const Tensor& self) {
   return map_nt(self, at::cos);
 }

+Tensor _pin_memory_nested(const Tensor& self, c10::optional<Device> device) {
+  auto* nt_input = get_nested_tensor_impl(self);
+  const auto& input_buffer = nt_input->get_unsafe_storage_as_tensor();
+  return wrap_buffer(
+      at::_pin_memory(input_buffer, device),
+      nt_input->get_nested_sizes(),
+      nt_input->get_nested_strides(),
+      nt_input->get_storage_offsets());
+}
+
 } // namespace native
 } // namespace at
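A rough Python-level picture of what the kernel above does (an illustrative sketch, assuming a CUDA build): the nested tensor's single backing buffer is pinned via at::_pin_memory and rewrapped with the original nested sizes, strides, and storage offsets, so per-component shapes and values are unchanged while the storage moves into page-locked memory.

import torch

nt = torch.nested.nested_tensor([torch.randn(2, 4), torch.randn(3, 4)])

if torch.cuda.is_available():
    pinned = nt.pin_memory()
    # Same nested structure and values, different (page-locked) backing storage.
    for a, b in zip(nt.unbind(), pinned.unbind()):
        assert a.shape == b.shape and torch.equal(a, b)
    assert nt.data_ptr() != pinned.data_ptr()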
@@ -36,6 +36,11 @@ bool is_pinned(const Tensor& self, c10::optional<at::Device> device) {
 at::Tensor _pin_memory(const Tensor& self, c10::optional<at::Device> device) {
   TORCH_CHECK(self.device().is_cpu(), "cannot pin '", self.toString(), "' only dense CPU tensors can be pinned");
   DispatchKeySet _dk = c10::DispatchKeySet(c10::computeDispatchKey(c10::nullopt, self.layout(), device.value_or(at::kCUDA)));
+  if (self.is_nested()) {
+    constexpr auto nested_key_set = c10::DispatchKeySet(
+        {c10::DispatchKey::NestedTensor, c10::DispatchKey::AutogradNestedTensor});
+    _dk = _dk.add(self.key_set() & nested_key_set);
+  }
   return at::_ops::_pin_memory::redispatch(_dk, self, device);
 }
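The user-visible consequences of this dispatch logic, sketched in Python (illustrative, assuming a CUDA build): the TORCH_CHECK still rejects non-CPU inputs, and for a nested CPU tensor the added NestedTensor/AutogradNestedTensor keys make the redispatch land on _pin_memory_nested, so the result is still a nested tensor.

import torch

if torch.cuda.is_available():
    # Only CPU tensors can be pinned; a CUDA-resident tensor hits the TORCH_CHECK above.
    try:
        torch.randn(4, device="cuda").pin_memory()
    except RuntimeError as e:
        print(e)  # "cannot pin ... only dense CPU tensors can be pinned"

    # A nested CPU tensor is routed to the nested kernel and stays nested.
    pinned = torch.nested.nested_tensor([torch.randn(3), torch.randn(5)]).pin_memory()
    assert pinned.is_nested and pinned.is_pinned()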
@@ -2282,16 +2282,22 @@ class TestDataLoaderDeviceType(TestCase):
         dataset = [torch.nested.nested_tensor([torch.randn(5)], device=device) for _ in range(10)]

-        loader = torch.utils.data.DataLoader(
-            dataset,
-            batch_size=1,
-            num_workers=4,
-            collate_fn=_clone_collate,
-            multiprocessing_context=context,
-        )
+        pin_memory_settings = [False]
+        if device == 'cpu' and torch.cuda.is_available():
+            pin_memory_settings.append(True)

-        for i, batch in enumerate(loader):
-            self.assertEqual(batch[0], dataset[i])
+        for pin_memory in pin_memory_settings:
+            loader = torch.utils.data.DataLoader(
+                dataset,
+                batch_size=1,
+                num_workers=4,
+                collate_fn=_clone_collate,
+                pin_memory=pin_memory,
+                multiprocessing_context=context,
+            )
+
+            for i, batch in enumerate(loader):
+                self.assertEqual(batch[0], dataset[i])

         # Error case: default collate_fn doesn't currently support batches of nested tensors.
         # Following the current semantics, we'd need to stack them, which isn't possible atm.
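On the comment above: torch.stack cannot combine variable-length nested tensors, so the default collate_fn cannot batch them and the test supplies a custom one. The _clone_collate helper used here is defined elsewhere in the test file and is not shown in this diff; the clone_collate below is only a hypothetical stand-in illustrating the pattern.

import torch

def clone_collate(batch):
    # Hypothetical stand-in for the test's _clone_collate: don't stack nested
    # tensors (not possible for ragged shapes); just clone each sample.
    return [sample.clone() for sample in batch]

loader = torch.utils.data.DataLoader(
    [torch.nested.nested_tensor([torch.randn(5)]) for _ in range(10)],
    batch_size=1,
    collate_fn=clone_collate,
    pin_memory=torch.cuda.is_available(),  # exercises the new nested pin_memory path
)

for batch in loader:
    assert batch[0].is_nested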
@@ -15,6 +15,7 @@ from torch.testing._internal.common_device_type import (
     onlyCPU,
     onlyCUDA,
     skipMeta,
+    PYTORCH_CUDA_MEMCHECK,
 )
 from torch.testing._internal.common_dtype import floating_types_and_half
 from torch.testing._internal.common_utils import (
@@ -2941,6 +2942,20 @@ class TestNestedTensorSubclass(TestCase):
         self.assertEqual(b, nt_contiguous)
         self.assertEqual(b, nt_noncontiguous)

+    @unittest.skipIf(PYTORCH_CUDA_MEMCHECK, "is_pinned uses failure to detect pointer property")
+    @onlyCUDA
+    def test_pin_memory(self, device):
+        nt_contiguous, nt_noncontiguous = random_nt_noncontiguous_pair((2, 3, 6, 7))
+        for nt in [nt_contiguous, nt_noncontiguous]:
+            self.assertFalse(nt.is_pinned())
+            pinned = nt.pin_memory(device)
+            self.assertTrue(pinned.is_pinned())
+            self.assertEqual(nt, pinned)
+            self.assertNotEqual(nt.data_ptr(), pinned.data_ptr())
+            # test that pin_memory on already pinned tensor has no effect
+            self.assertIs(pinned, pinned.pin_memory())
+            self.assertEqual(pinned.data_ptr(), pinned.pin_memory().data_ptr())
+

 instantiate_parametrized_tests(TestNestedTensor)
 instantiate_device_type_tests(TestNestedTensorDeviceType, globals())