diff --git a/aten/src/ATen/native/mps/operations/Copy.mm b/aten/src/ATen/native/mps/operations/Copy.mm
index a73e4180317..1cda971fa50 100644
--- a/aten/src/ATen/native/mps/operations/Copy.mm
+++ b/aten/src/ATen/native/mps/operations/Copy.mm
@@ -206,11 +206,11 @@ static void copy_to_mps_stride_contig(at::Tensor& dst, const at::Tensor& src, bo
     NSUInteger sourceOffset = 0;
 
     void* alignedPtr = pageAlignedBlockPtr(host_src, (NSUInteger)size_to_copy, &alignedLength);
+    sourceOffset = uintptr_t(host_src) - uintptr_t(alignedPtr);
     id<MTLBuffer> sourceBuffer = [device newBufferWithBytesNoCopy:alignedPtr
                                                            length:alignedLength
                                                           options:options
                                                       deallocator:nil];
-    sourceOffset = uintptr_t(host_src) - uintptr_t(alignedPtr);
 
     stream->copy_and_sync(sourceBuffer, destBuffer, size_to_copy, sourceOffset, dst_byte_offset, non_blocking);
     [sourceBuffer release];
diff --git a/aten/src/ATen/native/mps/operations/View.mm b/aten/src/ATen/native/mps/operations/View.mm
index 372f584ad2b..8a5a95ce328 100644
--- a/aten/src/ATen/native/mps/operations/View.mm
+++ b/aten/src/ATen/native/mps/operations/View.mm
@@ -34,8 +34,7 @@ static std::string getStridedKey(const ScalarType& self_dtype, const ScalarType&
 
 // initializes the MTLBuffers for tensor data and runs the MPSGraph for the view op
 static Tensor& runViewGraph(ViewCachedGraph* cachedGraph, const at::Tensor& src, Tensor& output,
-                            bool needsScatter, bool requires_sync = false)
-{
+                            bool needsScatter, bool requires_sync = false) {
   const id<MTLBuffer> sourceBuffer = getMTLBufferStorage(src);
   const id<MTLBuffer> outputBuffer = getMTLBufferStorage(output);
 
@@ -721,8 +720,7 @@ Tensor gatherViewTensor(const at::Tensor& src, at::Tensor& dst)
   return runViewGraph(cachedGraph, src, dst.has_storage() ? dst : output, /*needsScatter*/ false, requires_sync);
 }
 
-Tensor& scatterViewTensor(const at::Tensor& src, at::Tensor& output)
-{
+Tensor& scatterViewTensor(const at::Tensor& src, at::Tensor& output) {
   ViewCachedGraph* cachedGraph = createViewGraph(output, src, output.sizes(), output.strides(),
                                                  output.storage_offset(), /*needsScatter*/ true);
   return runViewGraph(cachedGraph, src, output, /*needsScatter*/ true, /*requires_sync*/ true);
@@ -731,8 +729,7 @@ Tensor& scatterViewTensor(const at::Tensor& src, at::Tensor& output)
 } // namespace mps
 
 // implementation of as_strided() op
-Tensor as_strided_tensorimpl_mps(const Tensor& self, IntArrayRef size, IntArrayRef stride, c10::optional<int64_t> storage_offset_)
-{
+Tensor as_strided_tensorimpl_mps(const Tensor& self, IntArrayRef size, IntArrayRef stride, c10::optional<int64_t> storage_offset_) {
   auto storage_offset = storage_offset_.value_or(self.storage_offset());
   auto result = detail::make_tensor<TensorImpl>(c10::TensorImpl::VIEW, Storage(self.storage()), self.key_set(), self.dtype());
   setStrided(result, size, stride, storage_offset);
diff --git a/test/test_mps.py b/test/test_mps.py
index a39ffc6a871..b7eb6ea8904 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -1598,6 +1598,19 @@ class TestMPS(TestCase):
 
         self.assertEqual(x_cpu, x.cpu())
 
+    def test_cpu_to_strided_mps_copy(self):
+        # https://github.com/pytorch/pytorch/issues/86975
+
+        a1 = torch.Tensor([[1, 2], [3, 4], [5, 6]]).to(torch.device("mps"))
+        b1 = torch.Tensor([-1, -1])
+        a1[1:, 1] = b1
+
+        a2 = torch.Tensor([[1, 2], [3, 4], [5, 6]]).to(torch.device("mps"))
+        b2 = torch.Tensor([-1, -1]).to(torch.device("mps"))
+        a2[1:, 1] = b2
+
+        self.assertEqual(a1, a2)
+
     def test_view_slice(self):
         # https://github.com/pytorch/pytorch/issues/83995
         NUM_SAMPLES = 60
@@ -6990,7 +7003,6 @@ class TestViewOpsMPS(TestCase):
         self.assertRaises(RuntimeError, lambda: tensor.view(7, -1))
         self.assertRaises(RuntimeError, lambda: tensor.view(15, -1, -1))
 
-    # RuntimeError: Invalid device for storage: mps
     def test_contiguous(self, device="mps"):
         x = torch.randn(1, 16, 5, 5, device=device)
         self.assertTrue(x.is_contiguous())