mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
[MPS] Add testcase for copying cpu tensors into strided mps tensors (#91784)
Fixes https://github.com/pytorch/pytorch/issues/86975 If the destination is a strided MPS tensor and the source is a CPU tensor, we cannot perform a blit directly to copy the memory from the CPU tensor into the MPS tensor. We need to scatter the data into the right indices. ``` a1 = torch.Tensor([[1,2],[3,4], [5,6]]).to(torch.device("mps")) b1 = torch.Tensor([-1, -1]) a1[1:,1] = b1 # strided MPS destination / contiguous CPU source ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/91784 Approved by: https://github.com/kulinseth
This commit is contained in:
parent
09c2b2af53
commit
0a677f2335
|
|
@ -206,11 +206,11 @@ static void copy_to_mps_stride_contig(at::Tensor& dst, const at::Tensor& src, bo
|
|||
NSUInteger sourceOffset = 0;
|
||||
|
||||
void* alignedPtr = pageAlignedBlockPtr(host_src, (NSUInteger)size_to_copy, &alignedLength);
|
||||
sourceOffset = uintptr_t(host_src) - uintptr_t(alignedPtr);
|
||||
id<MTLBuffer> sourceBuffer = [device newBufferWithBytesNoCopy:alignedPtr
|
||||
length:alignedLength
|
||||
options:options
|
||||
deallocator:nil];
|
||||
sourceOffset = uintptr_t(host_src) - uintptr_t(alignedPtr);
|
||||
|
||||
stream->copy_and_sync(sourceBuffer, destBuffer, size_to_copy, sourceOffset, dst_byte_offset, non_blocking);
|
||||
[sourceBuffer release];
|
||||
|
|
|
|||
|
|
@ -34,8 +34,7 @@ static std::string getStridedKey(const ScalarType& self_dtype, const ScalarType&
|
|||
|
||||
// initializes the MTLBuffers for tensor data and runs the MPSGraph for the view op
|
||||
static Tensor& runViewGraph(ViewCachedGraph* cachedGraph, const at::Tensor& src, Tensor& output,
|
||||
bool needsScatter, bool requires_sync = false)
|
||||
{
|
||||
bool needsScatter, bool requires_sync = false) {
|
||||
const id<MTLBuffer> sourceBuffer = getMTLBufferStorage(src);
|
||||
const id<MTLBuffer> outputBuffer = getMTLBufferStorage(output);
|
||||
|
||||
|
|
@ -721,8 +720,7 @@ Tensor gatherViewTensor(const at::Tensor& src, at::Tensor& dst)
|
|||
return runViewGraph(cachedGraph, src, dst.has_storage() ? dst : output, /*needsScatter*/ false, requires_sync);
|
||||
}
|
||||
|
||||
Tensor& scatterViewTensor(const at::Tensor& src, at::Tensor& output)
|
||||
{
|
||||
Tensor& scatterViewTensor(const at::Tensor& src, at::Tensor& output) {
|
||||
ViewCachedGraph* cachedGraph = createViewGraph(output, src, output.sizes(), output.strides(),
|
||||
output.storage_offset(), /*needsScatter*/ true);
|
||||
return runViewGraph(cachedGraph, src, output, /*needsScatter*/ true, /*requires_sync*/ true);
|
||||
|
|
@ -731,8 +729,7 @@ Tensor& scatterViewTensor(const at::Tensor& src, at::Tensor& output)
|
|||
} // namespace mps
|
||||
|
||||
// implementation of as_strided() op
|
||||
Tensor as_strided_tensorimpl_mps(const Tensor& self, IntArrayRef size, IntArrayRef stride, c10::optional<int64_t> storage_offset_)
|
||||
{
|
||||
Tensor as_strided_tensorimpl_mps(const Tensor& self, IntArrayRef size, IntArrayRef stride, c10::optional<int64_t> storage_offset_) {
|
||||
auto storage_offset = storage_offset_.value_or(self.storage_offset());
|
||||
auto result = detail::make_tensor<TensorImpl>(c10::TensorImpl::VIEW, Storage(self.storage()), self.key_set(), self.dtype());
|
||||
setStrided(result, size, stride, storage_offset);
|
||||
|
|
|
|||
|
|
@ -1598,6 +1598,19 @@ class TestMPS(TestCase):
|
|||
|
||||
self.assertEqual(x_cpu, x.cpu())
|
||||
|
||||
def test_cpu_to_strided_mps_copy(self):
|
||||
# https://github.com/pytorch/pytorch/issues/86975
|
||||
|
||||
a1 = torch.Tensor([[1, 2], [3, 4], [5, 6]]).to(torch.device("mps"))
|
||||
b1 = torch.Tensor([-1, -1])
|
||||
a1[1:, 1] = b1
|
||||
|
||||
a2 = torch.Tensor([[1, 2], [3, 4], [5, 6]]).to(torch.device("mps"))
|
||||
b2 = torch.Tensor([-1, -1]).to(torch.device("mps"))
|
||||
a2[1:, 1] = b2
|
||||
|
||||
self.assertEqual(a1, a2)
|
||||
|
||||
def test_view_slice(self):
|
||||
# https://github.com/pytorch/pytorch/issues/83995
|
||||
NUM_SAMPLES = 60
|
||||
|
|
@ -6990,7 +7003,6 @@ class TestViewOpsMPS(TestCase):
|
|||
self.assertRaises(RuntimeError, lambda: tensor.view(7, -1))
|
||||
self.assertRaises(RuntimeError, lambda: tensor.view(15, -1, -1))
|
||||
|
||||
# RuntimeError: Invalid device for storage: mps
|
||||
def test_contiguous(self, device="mps"):
|
||||
x = torch.randn(1, 16, 5, 5, device=device)
|
||||
self.assertTrue(x.is_contiguous())
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user