diff --git a/aten/src/ATen/mps/MPSHooks.mm b/aten/src/ATen/mps/MPSHooks.mm
index a2ec221c1bf..34fbd31af91 100644
--- a/aten/src/ATen/mps/MPSHooks.mm
+++ b/aten/src/ATen/mps/MPSHooks.mm
@@ -70,7 +70,10 @@ void MPSHooks::commitStream() const {
 }
 
 void* MPSHooks::getCommandBuffer() const {
-  return at::mps::getDefaultMPSStream()->commandBuffer();
+  auto stream = at::mps::getDefaultMPSStream();
+  // Release the pending computeCommandEncoder, as the extension is likely to allocate a new one
+  stream->endKernelCoalescing();
+  return stream->commandBuffer();
 }
 
 void* MPSHooks::getDispatchQueue() const {
diff --git a/test/cpp_extensions/mps_extension.mm b/test/cpp_extensions/mps_extension.mm
index 882e5c5603e..30b70a76563 100644
--- a/test/cpp_extensions/mps_extension.mm
+++ b/test/cpp_extensions/mps_extension.mm
@@ -13,6 +13,11 @@ kernel void add_arrays(device const float* inA,
 {
     result[index] = inA[index] + inB[index];
 }
+
+kernel void add_one(device float* data,
+                    uint index [[thread_position_in_grid]]) {
+  data[index] += 1.0;
+}
 )MPS_ADD_ARRAYS");
 
 at::Tensor get_cpu_add_output(at::Tensor & cpu_input1, at::Tensor & cpu_input2) {
@@ -50,7 +55,31 @@ at::Tensor get_mps_add_output(at::Tensor & mps_input1, at::Tensor & mps_input2)
   return mps_output;
 }
 
+void mps_add_one_new_encoder(const at::Tensor& input) {
+  using namespace at::native::mps;
+  TORCH_CHECK(input.is_mps());
+  TORCH_CHECK(input.numel() > 0);
+
+  @autoreleasepool {
+    auto kernelPSO = lib.getPipelineStateForFunc("add_one");
+    auto serialQueue = torch::mps::get_dispatch_queue();
+
+    dispatch_sync(serialQueue, ^(){
+      auto commandBuffer = torch::mps::get_command_buffer();
+      // Start a compute pass.
+      auto computeEncoder = [commandBuffer computeCommandEncoder];
+      TORCH_CHECK(computeEncoder, "Failed to create compute command encoder");
+      [computeEncoder setComputePipelineState: kernelPSO];
+      mtl_setArgs(computeEncoder, input);
+      mtl_dispatch1DJob(computeEncoder, kernelPSO, input.numel());
+      [computeEncoder endEncoding];
+      torch::mps::commit();
+    });
+  }
+}
+
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("get_cpu_add_output", &get_cpu_add_output);
   m.def("get_mps_add_output", &get_mps_add_output);
+  m.def("mps_add_one_new_context", &mps_add_one_new_encoder);
 }
diff --git a/test/test_cpp_extensions_jit.py b/test/test_cpp_extensions_jit.py
index fd80c7fa565..e93167296a0 100644
--- a/test/test_cpp_extensions_jit.py
+++ b/test/test_cpp_extensions_jit.py
@@ -220,6 +220,12 @@ class TestCppExtensionJIT(common.TestCase):
 
         self.assertEqual(cpu_output, mps_output.to("cpu"))
 
+        # Regression test for https://github.com/pytorch/pytorch/issues/163721
+        lib = torch.mps.compile_shader("void kernel noop(device float *x) {}")
+        lib.noop(mps_output)
+        module.mps_add_one_new_context(mps_output)
+        self.assertEqual(cpu_output + 1.0, mps_output.to("cpu"))
+
     def _run_jit_cuda_archflags(self, flags, expected):
         # Compile an extension with given `flags`
         def _check_cuobjdump_output(expected_values, is_ptx=False):