pin_memory support for NT (#110404)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/110404
Approved by: https://github.com/cpuhrsch, https://github.com/albanD
ghstack dependencies: #110292
This commit is contained in:
Joel Schlosser 2023-10-04 17:43:48 -04:00 committed by PyTorch MergeBot
parent cc1de49340
commit 3597325bc7
5 changed files with 47 additions and 10 deletions

View File

@ -4405,7 +4405,7 @@
- func: is_pinned(Tensor self, Device? device=None) -> bool - func: is_pinned(Tensor self, Device? device=None) -> bool
variants: method variants: method
dispatch: dispatch:
CUDA: is_pinned_cuda NestedTensorCUDA, CUDA: is_pinned_cuda
MPS: is_pinned_mps MPS: is_pinned_mps
CompositeExplicitAutograd: is_pinned_default CompositeExplicitAutograd: is_pinned_default
@ -4419,6 +4419,7 @@
dispatch: dispatch:
CUDA: _pin_memory_cuda CUDA: _pin_memory_cuda
MPS: _pin_memory_mps MPS: _pin_memory_mps
NestedTensorCUDA, NestedTensorCPU: _pin_memory_nested
autogen: _pin_memory.out autogen: _pin_memory.out
- func: pinverse(Tensor self, float rcond=1e-15) -> Tensor - func: pinverse(Tensor self, float rcond=1e-15) -> Tensor

View File

@ -132,5 +132,15 @@ Tensor cos_nested(const Tensor& self) {
return map_nt(self, at::cos); return map_nt(self, at::cos);
} }
Tensor _pin_memory_nested(const Tensor& self, c10::optional<Device> device) {
auto* nt_input = get_nested_tensor_impl(self);
const auto& input_buffer = nt_input->get_unsafe_storage_as_tensor();
return wrap_buffer(
at::_pin_memory(input_buffer, device),
nt_input->get_nested_sizes(),
nt_input->get_nested_strides(),
nt_input->get_storage_offsets());
}
} // namespace native } // namespace native
} // namespace at } // namespace at

View File

@ -36,6 +36,11 @@ bool is_pinned(const Tensor& self, c10::optional<at::Device> device) {
at::Tensor _pin_memory(const Tensor& self, c10::optional<at::Device> device) { at::Tensor _pin_memory(const Tensor& self, c10::optional<at::Device> device) {
TORCH_CHECK(self.device().is_cpu(), "cannot pin '", self.toString(), "' only dense CPU tensors can be pinned"); TORCH_CHECK(self.device().is_cpu(), "cannot pin '", self.toString(), "' only dense CPU tensors can be pinned");
DispatchKeySet _dk = c10::DispatchKeySet(c10::computeDispatchKey(c10::nullopt, self.layout(), device.value_or(at::kCUDA))); DispatchKeySet _dk = c10::DispatchKeySet(c10::computeDispatchKey(c10::nullopt, self.layout(), device.value_or(at::kCUDA)));
if (self.is_nested()) {
constexpr auto nested_key_set = c10::DispatchKeySet(
{c10::DispatchKey::NestedTensor, c10::DispatchKey::AutogradNestedTensor});
_dk = _dk.add(self.key_set() & nested_key_set);
}
return at::_ops::_pin_memory::redispatch(_dk, self, device); return at::_ops::_pin_memory::redispatch(_dk, self, device);
} }

View File

@ -2282,16 +2282,22 @@ class TestDataLoaderDeviceType(TestCase):
dataset = [torch.nested.nested_tensor([torch.randn(5)], device=device) for _ in range(10)] dataset = [torch.nested.nested_tensor([torch.randn(5)], device=device) for _ in range(10)]
loader = torch.utils.data.DataLoader( pin_memory_settings = [False]
dataset, if device == 'cpu' and torch.cuda.is_available():
batch_size=1, pin_memory_settings.append(True)
num_workers=4,
collate_fn=_clone_collate,
multiprocessing_context=context,
)
for i, batch in enumerate(loader): for pin_memory in pin_memory_settings:
self.assertEqual(batch[0], dataset[i]) loader = torch.utils.data.DataLoader(
dataset,
batch_size=1,
num_workers=4,
collate_fn=_clone_collate,
pin_memory=pin_memory,
multiprocessing_context=context,
)
for i, batch in enumerate(loader):
self.assertEqual(batch[0], dataset[i])
# Error case: default collate_fn doesn't currently support batches of nested tensors. # Error case: default collate_fn doesn't currently support batches of nested tensors.
# Following the current semantics, we'd need to stack them, which isn't possible atm. # Following the current semantics, we'd need to stack them, which isn't possible atm.

View File

@ -15,6 +15,7 @@ from torch.testing._internal.common_device_type import (
onlyCPU, onlyCPU,
onlyCUDA, onlyCUDA,
skipMeta, skipMeta,
PYTORCH_CUDA_MEMCHECK,
) )
from torch.testing._internal.common_dtype import floating_types_and_half from torch.testing._internal.common_dtype import floating_types_and_half
from torch.testing._internal.common_utils import ( from torch.testing._internal.common_utils import (
@ -2941,6 +2942,20 @@ class TestNestedTensorSubclass(TestCase):
self.assertEqual(b, nt_contiguous) self.assertEqual(b, nt_contiguous)
self.assertEqual(b, nt_noncontiguous) self.assertEqual(b, nt_noncontiguous)
@unittest.skipIf(PYTORCH_CUDA_MEMCHECK, "is_pinned uses failure to detect pointer property")
@onlyCUDA
def test_pin_memory(self, device):
nt_contiguous, nt_noncontiguous = random_nt_noncontiguous_pair((2, 3, 6, 7))
for nt in [nt_contiguous, nt_noncontiguous]:
self.assertFalse(nt.is_pinned())
pinned = nt.pin_memory(device)
self.assertTrue(pinned.is_pinned())
self.assertEqual(nt, pinned)
self.assertNotEqual(nt.data_ptr(), pinned.data_ptr())
# test that pin_memory on already pinned tensor has no effect
self.assertIs(pinned, pinned.pin_memory())
self.assertEqual(pinned.data_ptr(), pinned.pin_memory().data_ptr())
instantiate_parametrized_tests(TestNestedTensor) instantiate_parametrized_tests(TestNestedTensor)
instantiate_device_type_tests(TestNestedTensorDeviceType, globals()) instantiate_device_type_tests(TestNestedTensorDeviceType, globals())