[MPS][BE] Do not copy sizes/strides unnecessarily (#154670)

Just pass them as args to `mtl_setArgs`; the metaprogramming there should deal with the rest (a sketch of the idea is below).
Also use `mtl_dispatch1DJob` instead of computing the max threadgroup size by hand.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154670
Approved by: https://github.com/dcci
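For readers who haven't seen the `mtl_setArgs` machinery, here is a minimal, self-contained sketch of the metaprogramming idea the commit message relies on: a variadic helper inspects each argument's type and binds an array view (what `Tensor::sizes()`/`strides()` return) by pointer and length, so callers do not have to copy it into a temporary `std::vector<int64_t>` first. `FakeEncoder`, `ArrayView`, `bind_one`, and `set_args` are illustrative stand-ins, not the real PyTorch/Metal types or the actual `mtl_setArgs` implementation.

// Standalone C++17 sketch of dispatch-on-type argument binding (illustrative only).
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <type_traits>
#include <vector>

struct ArrayView {            // stand-in for an IntArrayRef-like view: pointer + length, no ownership
  const int64_t* data;
  std::size_t size;
};

struct FakeEncoder {          // stand-in for a compute command encoder
  void setBytes(const void* ptr, std::size_t bytes, unsigned index) {
    std::printf("arg %u: %zu bytes bound in place (no temporary vector)\n", index, bytes);
    (void)ptr;
  }
};

// Bind one argument: array views go by pointer + length ...
inline void bind_one(FakeEncoder& enc, unsigned idx, const ArrayView& v) {
  enc.setBytes(v.data, v.size * sizeof(int64_t), idx);
}
// ... while trivially copyable scalars go by value.
template <typename T, std::enable_if_t<std::is_trivially_copyable_v<T>, int> = 0>
void bind_one(FakeEncoder& enc, unsigned idx, const T& scalar) {
  enc.setBytes(&scalar, sizeof(T), idx);
}

// Variadic front end: fold over the arguments, assigning consecutive buffer
// indices starting at StartIdx (mirrors the mtl_setArgs<4>(...) style offset).
template <unsigned StartIdx = 0, typename... Args>
void set_args(FakeEncoder& enc, const Args&... args) {
  unsigned idx = StartIdx;
  (bind_one(enc, idx++, args), ...);
}

int main() {
  FakeEncoder enc;
  std::vector<int64_t> sizes = {2, 3, 4};
  ArrayView sizes_view{sizes.data(), sizes.size()};  // view only, no copy of the sizes
  uint32_t ndim = 3;
  set_args<4>(enc, ndim, sizes_view);                // scalar and array view handled uniformly
}

The same dispatch-on-type idea is what lets the diff below pass `self.sizes()` and the stride arrays straight through to `mtl_setArgs`.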
Author: Nikita Shulga (2025-05-29 15:48:50 -07:00)
Committer: PyTorch MergeBot
Parent: 61bfb3df9f
Commit: 0134150ebb


@@ -311,25 +311,16 @@ TORCH_IMPL_FUNC(index_copy_out_mps)(const Tensor& self,
   dispatch_sync_with_rethrow(stream->queue(), ^() {
     @autoreleasepool {
-      id<MTLComputeCommandEncoder> computeEncoder = stream->commandEncoder();
-      [computeEncoder setComputePipelineState:indexCopyPSO];
-      mtl_setArgs(computeEncoder, result, self, source, index);
-      std::vector<int64_t> sizes_buffer(self.sizes().begin(), self.sizes().end());
+      auto computeEncoder = stream->commandEncoder();
       uint32_t dim_arg = static_cast<uint32_t>(dim);
       uint32_t ndim = self.dim();
       uint32_t indices_numel = index.numel();
-      int numThreads = result.numel();
-      mtl_setArgs<4>(computeEncoder, dim_arg, sizes_buffer, ndim, indices_numel);
+      [computeEncoder setComputePipelineState:indexCopyPSO];
+      mtl_setArgs(computeEncoder, result, self, source, index, dim_arg, self.sizes(), ndim, indices_numel);
       if (!is_dense) {
-        std::vector<int64_t> input_strides_buffer(self.strides().begin(), self.strides().end());
-        std::vector<int64_t> output_strides_buffer(result.strides().begin(), result.strides().end());
-        std::vector<int64_t> source_strides_buffer(source.strides().begin(), source.strides().end());
-        mtl_setArgs<8>(computeEncoder, input_strides_buffer, output_strides_buffer, source_strides_buffer);
+        mtl_setArgs<8>(computeEncoder, self.strides(), result.strides(), source.strides());
       }
-      MTLSize gridSize = MTLSizeMake(numThreads, 1, 1);
-      MTLSize threadgroupSize =
-          MTLSizeMake(std::min(numThreads, static_cast<int32_t>(indexCopyPSO.maxTotalThreadsPerThreadgroup)), 1, 1);
-      [computeEncoder dispatchThreads:gridSize threadsPerThreadgroup:threadgroupSize];
+      mtl_dispatch1DJob(computeEncoder, indexCopyPSO, result.numel());
     }
   });
 }
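For reference, the removed lines above built the 1-D grid and threadgroup sizes by hand: one thread per element of `result`, with the threadgroup width clamped to the pipeline's `maxTotalThreadsPerThreadgroup`. A helper along the lines of `mtl_dispatch1DJob` centralizes exactly that computation. The sketch below shows one plausible shape for such a helper; the name `dispatch_1d_job` and its exact signature are assumptions for illustration, not necessarily what PyTorch's helper looks like, and it assumes a device supporting non-uniform threadgroups, as the removed `dispatchThreads:` call already did.

// Objective-C++ sketch of a 1-D dispatch helper (illustrative, not PyTorch's real mtl_dispatch1DJob).
#include <algorithm>
#import <Metal/Metal.h>

static inline void dispatch_1d_job(id<MTLComputeCommandEncoder> encoder,
                                   id<MTLComputePipelineState> pso,
                                   NSUInteger length) {
  // One thread per element; threadgroup width clamped to the pipeline's limit.
  const NSUInteger tgWidth = std::min(length, pso.maxTotalThreadsPerThreadgroup);
  [encoder dispatchThreads:MTLSizeMake(length, 1, 1)
     threadsPerThreadgroup:MTLSizeMake(tgWidth, 1, 1)];
}

Besides cutting the repetition at each call site, folding the dispatch into a helper also drops the `int`/`int32_t` casting the hand-written version needed around `numThreads`.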