mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[MPS][BE] Do not copy sizes/strides unnecessarily (#154670)
Just pass them as args to `mtl_setArgs`; metaprogramming should deal with the rest. Also use `mtl_dispatch1DJob` instead of computing max threadgroup size by hand. Pull Request resolved: https://github.com/pytorch/pytorch/pull/154670 Approved by: https://github.com/dcci
This commit is contained in:
parent
61bfb3df9f
commit
0134150ebb
|
|
@ -311,25 +311,16 @@ TORCH_IMPL_FUNC(index_copy_out_mps)(const Tensor& self,
|
|||
|
||||
dispatch_sync_with_rethrow(stream->queue(), ^() {
|
||||
@autoreleasepool {
|
||||
id<MTLComputeCommandEncoder> computeEncoder = stream->commandEncoder();
|
||||
[computeEncoder setComputePipelineState:indexCopyPSO];
|
||||
mtl_setArgs(computeEncoder, result, self, source, index);
|
||||
std::vector<int64_t> sizes_buffer(self.sizes().begin(), self.sizes().end());
|
||||
auto computeEncoder = stream->commandEncoder();
|
||||
uint32_t dim_arg = static_cast<uint32_t>(dim);
|
||||
uint32_t ndim = self.dim();
|
||||
uint32_t indices_numel = index.numel();
|
||||
int numThreads = result.numel();
|
||||
mtl_setArgs<4>(computeEncoder, dim_arg, sizes_buffer, ndim, indices_numel);
|
||||
[computeEncoder setComputePipelineState:indexCopyPSO];
|
||||
mtl_setArgs(computeEncoder, result, self, source, index, dim_arg, self.sizes(), ndim, indices_numel);
|
||||
if (!is_dense) {
|
||||
std::vector<int64_t> input_strides_buffer(self.strides().begin(), self.strides().end());
|
||||
std::vector<int64_t> output_strides_buffer(result.strides().begin(), result.strides().end());
|
||||
std::vector<int64_t> source_strides_buffer(source.strides().begin(), source.strides().end());
|
||||
mtl_setArgs<8>(computeEncoder, input_strides_buffer, output_strides_buffer, source_strides_buffer);
|
||||
mtl_setArgs<8>(computeEncoder, self.strides(), result.strides(), source.strides());
|
||||
}
|
||||
MTLSize gridSize = MTLSizeMake(numThreads, 1, 1);
|
||||
MTLSize threadgroupSize =
|
||||
MTLSizeMake(std::min(numThreads, static_cast<int32_t>(indexCopyPSO.maxTotalThreadsPerThreadgroup)), 1, 1);
|
||||
[computeEncoder dispatchThreads:gridSize threadsPerThreadgroup:threadgroupSize];
|
||||
mtl_dispatch1DJob(computeEncoder, indexCopyPSO, result.numel());
|
||||
}
|
||||
});
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user