[MPS] Add testcase for copying cpu tensors into strided mps tensors (#91784)

Fixes https://github.com/pytorch/pytorch/issues/86975

If the destination is a strided MPS tensor and the source is a CPU tensor, we cannot perform a blit directly to copy the memory from the CPU tensor into the MPS tensor. We need to scatter the data into the right indices.
```
        a1 = torch.Tensor([[1,2],[3,4], [5,6]]).to(torch.device("mps"))
        b1 = torch.Tensor([-1, -1])
        a1[1:,1] = b1  # strided MPS destination / contiguous CPU source
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91784
Approved by: https://github.com/kulinseth
This commit is contained in:
Denis Vieriu 2023-01-10 22:45:48 +00:00 committed by PyTorch MergeBot
parent 09c2b2af53
commit 0a677f2335
3 changed files with 17 additions and 8 deletions

View File

@ -206,11 +206,11 @@ static void copy_to_mps_stride_contig(at::Tensor& dst, const at::Tensor& src, bo
NSUInteger sourceOffset = 0;
void* alignedPtr = pageAlignedBlockPtr(host_src, (NSUInteger)size_to_copy, &alignedLength);
sourceOffset = uintptr_t(host_src) - uintptr_t(alignedPtr);
id<MTLBuffer> sourceBuffer = [device newBufferWithBytesNoCopy:alignedPtr
length:alignedLength
options:options
deallocator:nil];
sourceOffset = uintptr_t(host_src) - uintptr_t(alignedPtr);
stream->copy_and_sync(sourceBuffer, destBuffer, size_to_copy, sourceOffset, dst_byte_offset, non_blocking);
[sourceBuffer release];

View File

@ -34,8 +34,7 @@ static std::string getStridedKey(const ScalarType& self_dtype, const ScalarType&
// initializes the MTLBuffers for tensor data and runs the MPSGraph for the view op
static Tensor& runViewGraph(ViewCachedGraph* cachedGraph, const at::Tensor& src, Tensor& output,
bool needsScatter, bool requires_sync = false)
{
bool needsScatter, bool requires_sync = false) {
const id<MTLBuffer> sourceBuffer = getMTLBufferStorage(src);
const id<MTLBuffer> outputBuffer = getMTLBufferStorage(output);
@ -721,8 +720,7 @@ Tensor gatherViewTensor(const at::Tensor& src, at::Tensor& dst)
return runViewGraph(cachedGraph, src, dst.has_storage() ? dst : output, /*needsScatter*/ false, requires_sync);
}
Tensor& scatterViewTensor(const at::Tensor& src, at::Tensor& output)
{
Tensor& scatterViewTensor(const at::Tensor& src, at::Tensor& output) {
ViewCachedGraph* cachedGraph = createViewGraph(output, src, output.sizes(), output.strides(),
output.storage_offset(), /*needsScatter*/ true);
return runViewGraph(cachedGraph, src, output, /*needsScatter*/ true, /*requires_sync*/ true);
@ -731,8 +729,7 @@ Tensor& scatterViewTensor(const at::Tensor& src, at::Tensor& output)
} // namespace mps
// implementation of as_strided() op
Tensor as_strided_tensorimpl_mps(const Tensor& self, IntArrayRef size, IntArrayRef stride, c10::optional<int64_t> storage_offset_)
{
Tensor as_strided_tensorimpl_mps(const Tensor& self, IntArrayRef size, IntArrayRef stride, c10::optional<int64_t> storage_offset_) {
auto storage_offset = storage_offset_.value_or(self.storage_offset());
auto result = detail::make_tensor<TensorImpl>(c10::TensorImpl::VIEW, Storage(self.storage()), self.key_set(), self.dtype());
setStrided(result, size, stride, storage_offset);

View File

@ -1598,6 +1598,19 @@ class TestMPS(TestCase):
self.assertEqual(x_cpu, x.cpu())
def test_cpu_to_strided_mps_copy(self):
# https://github.com/pytorch/pytorch/issues/86975
a1 = torch.Tensor([[1, 2], [3, 4], [5, 6]]).to(torch.device("mps"))
b1 = torch.Tensor([-1, -1])
a1[1:, 1] = b1
a2 = torch.Tensor([[1, 2], [3, 4], [5, 6]]).to(torch.device("mps"))
b2 = torch.Tensor([-1, -1]).to(torch.device("mps"))
a2[1:, 1] = b2
self.assertEqual(a1, a2)
def test_view_slice(self):
# https://github.com/pytorch/pytorch/issues/83995
NUM_SAMPLES = 60
@ -6990,7 +7003,6 @@ class TestViewOpsMPS(TestCase):
self.assertRaises(RuntimeError, lambda: tensor.view(7, -1))
self.assertRaises(RuntimeError, lambda: tensor.view(15, -1, -1))
# RuntimeError: Invalid device for storage: mps
def test_contiguous(self, device="mps"):
x = torch.randn(1, 16, 5, 5, device=device)
self.assertTrue(x.is_contiguous())