pin_memory support for NT (#110404)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/110404
Approved by: https://github.com/cpuhrsch, https://github.com/albanD
ghstack dependencies: #110292
commit 3597325bc7
parent cc1de49340
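For orientation, here is a minimal usage sketch of what this change enables from Python. It is based on the test added in this commit rather than on any additional documented API, and it assumes a CUDA-capable build, since pinned (page-locked) host memory only matters for host-to-GPU transfers.

```python
import torch

# A CPU nested tensor built from variable-length components.
nt = torch.nested.nested_tensor([torch.randn(2, 5), torch.randn(3, 5)])

if torch.cuda.is_available():
    assert not nt.is_pinned()
    pinned = nt.pin_memory()  # pins the packed underlying buffer
    assert pinned.is_pinned()
    # Pinning an already-pinned nested tensor is a no-op that returns the same object.
    assert pinned.pin_memory() is pinned
    # Pinned host memory allows non-blocking copies to the GPU.
    on_gpu = pinned.to("cuda", non_blocking=True)
```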
@@ -4405,7 +4405,7 @@
 - func: is_pinned(Tensor self, Device? device=None) -> bool
   variants: method
   dispatch:
-    CUDA: is_pinned_cuda
+    NestedTensorCUDA, CUDA: is_pinned_cuda
     MPS: is_pinned_mps
     CompositeExplicitAutograd: is_pinned_default
 
@@ -4419,6 +4419,7 @@
   dispatch:
     CUDA: _pin_memory_cuda
     MPS: _pin_memory_mps
+    NestedTensorCUDA, NestedTensorCPU: _pin_memory_nested
   autogen: _pin_memory.out
 
 - func: pinverse(Tensor self, float rcond=1e-15) -> Tensor

@@ -132,5 +132,15 @@ Tensor cos_nested(const Tensor& self) {
   return map_nt(self, at::cos);
 }
 
+Tensor _pin_memory_nested(const Tensor& self, c10::optional<Device> device) {
+  auto* nt_input = get_nested_tensor_impl(self);
+  const auto& input_buffer = nt_input->get_unsafe_storage_as_tensor();
+  return wrap_buffer(
+      at::_pin_memory(input_buffer, device),
+      nt_input->get_nested_sizes(),
+      nt_input->get_nested_strides(),
+      nt_input->get_storage_offsets());
+}
+
 } // namespace native
 } // namespace at

@@ -36,6 +36,11 @@ bool is_pinned(const Tensor& self, c10::optional<at::Device> device) {
 at::Tensor _pin_memory(const Tensor& self, c10::optional<at::Device> device) {
   TORCH_CHECK(self.device().is_cpu(), "cannot pin '", self.toString(), "' only dense CPU tensors can be pinned");
   DispatchKeySet _dk = c10::DispatchKeySet(c10::computeDispatchKey(c10::nullopt, self.layout(), device.value_or(at::kCUDA)));
+  if (self.is_nested()) {
+    constexpr auto nested_key_set = c10::DispatchKeySet(
+        {c10::DispatchKey::NestedTensor, c10::DispatchKey::AutogradNestedTensor});
+    _dk = _dk.add(self.key_set() & nested_key_set);
+  }
   return at::_ops::_pin_memory::redispatch(_dk, self, device);
 }
 
@@ -2282,11 +2282,17 @@ class TestDataLoaderDeviceType(TestCase):
 
         dataset = [torch.nested.nested_tensor([torch.randn(5)], device=device) for _ in range(10)]
 
-        loader = torch.utils.data.DataLoader(
-            dataset,
-            batch_size=1,
-            num_workers=4,
-            collate_fn=_clone_collate,
-            multiprocessing_context=context,
-        )
+        pin_memory_settings = [False]
+        if device == 'cpu' and torch.cuda.is_available():
+            pin_memory_settings.append(True)
+
+        for pin_memory in pin_memory_settings:
+            loader = torch.utils.data.DataLoader(
+                dataset,
+                batch_size=1,
+                num_workers=4,
+                collate_fn=_clone_collate,
+                pin_memory=pin_memory,
+                multiprocessing_context=context,
+            )
 
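To make the DataLoader change above concrete, a small sketch of the end-to-end behavior under test. The `clone_collate` helper here is a hypothetical stand-in for the test's `_clone_collate`, and the final assertion reflects my understanding that the DataLoader's pin-memory step calls `pin_memory()` on each nested tensor sample; treat it as illustrative rather than a documented guarantee.

```python
import torch
from torch.utils.data import DataLoader

def clone_collate(batch):
    # Hypothetical stand-in for _clone_collate: keep the list of samples, just clone each one.
    return [x.clone() for x in batch]

dataset = [torch.nested.nested_tensor([torch.randn(5)]) for _ in range(10)]

if torch.cuda.is_available():
    loader = DataLoader(dataset, batch_size=1, collate_fn=clone_collate, pin_memory=True)
    for batch in loader:
        # With _pin_memory_nested registered, the pin-memory thread can pin nested tensor samples.
        assert batch[0].is_pinned()
```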
@@ -15,6 +15,7 @@ from torch.testing._internal.common_device_type import (
     onlyCPU,
     onlyCUDA,
     skipMeta,
+    PYTORCH_CUDA_MEMCHECK,
 )
 from torch.testing._internal.common_dtype import floating_types_and_half
 from torch.testing._internal.common_utils import (

@@ -2941,6 +2942,20 @@ class TestNestedTensorSubclass(TestCase):
         self.assertEqual(b, nt_contiguous)
         self.assertEqual(b, nt_noncontiguous)
 
+    @unittest.skipIf(PYTORCH_CUDA_MEMCHECK, "is_pinned uses failure to detect pointer property")
+    @onlyCUDA
+    def test_pin_memory(self, device):
+        nt_contiguous, nt_noncontiguous = random_nt_noncontiguous_pair((2, 3, 6, 7))
+        for nt in [nt_contiguous, nt_noncontiguous]:
+            self.assertFalse(nt.is_pinned())
+            pinned = nt.pin_memory(device)
+            self.assertTrue(pinned.is_pinned())
+            self.assertEqual(nt, pinned)
+            self.assertNotEqual(nt.data_ptr(), pinned.data_ptr())
+            # test that pin_memory on already pinned tensor has no effect
+            self.assertIs(pinned, pinned.pin_memory())
+            self.assertEqual(pinned.data_ptr(), pinned.pin_memory().data_ptr())
+
 
 instantiate_parametrized_tests(TestNestedTensor)
 instantiate_device_type_tests(TestNestedTensorDeviceType, globals())