[aoti][mps] mps constants support (#154287)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154287
Approved by: https://github.com/malfet
ghstack dependencies: #155752
This commit is contained in:
angelayi 2025-06-12 11:51:51 -07:00 committed by PyTorch MergeBot
parent 8821a9dc4e
commit a4ab392251
13 changed files with 126 additions and 4 deletions

View File

@ -78,6 +78,9 @@ struct TORCH_API MPSHooksInterface : AcceleratorHooksInterface {
virtual uint32_t acquireEvent(bool enable_timing) const {
FAIL_MPSHOOKS_FUNC(__func__);
}
Device getDeviceFromPtr(void* data) const override {
TORCH_CHECK(false, "Cannot get device of pointer on MPS without ATen_mps library. ");
}
virtual void releaseEvent(uint32_t event_id) const {
FAIL_MPSHOOKS_FUNC(__func__);
}

View File

@ -1,6 +1,7 @@
// Copyright © 2022 Apple Inc.
#pragma once
#include <ATen/Device.h>
#include <c10/core/Allocator.h>
#include <c10/macros/Macros.h>
#include <c10/util/Exception.h>
@ -70,4 +71,8 @@ TORCH_API bool is_available();
TORCH_API bool is_macos_13_or_newer(MacOSVersion version);
TORCH_API at::Allocator* GetMPSAllocator(bool useSharedAllocator = false);
inline Device getDeviceFromPtr(void* ptr) {
return {c10::DeviceType::MPS, 0};
}
} // namespace at::mps

View File

@ -18,6 +18,8 @@ struct MPSHooks : public at::MPSHooksInterface {
bool hasMPS() const override;
bool isOnMacOSorNewer(unsigned major, unsigned minor) const override;
Device getDeviceFromPtr(void* data) const override;
// MPSGeneratorImpl interface
const Generator& getDefaultGenerator(
DeviceIndex device_index = -1) const override;

View File

@ -129,6 +129,10 @@ void MPSHooks::recordEvent(uint32_t event_id) const {
at::mps::getMPSEventPool()->recordEvent(event_id, /* syncEvent*/ true);
}
Device MPSHooks::getDeviceFromPtr(void* data) const {
return at::mps::getDeviceFromPtr(data);
}
void MPSHooks::waitForEvent(uint32_t event_id) const {
at::mps::getMPSEventPool()->waitForEvent(event_id, /* syncEvent*/ true);
}

View File

@ -709,6 +709,7 @@ list(APPEND Caffe2_CPU_SRCS ${TORCH_SRCS})
if(USE_MPS)
list(APPEND Caffe2_CPU_SRCS ${Caffe2_MPS_SRCS})
list(APPEND Caffe2_CPU_SRCS ${TORCH_SRC_DIR}/csrc/inductor/aoti_torch/shim_mps.cpp)
list(APPEND Caffe2_CPU_SRCS ${TORCH_SRC_DIR}/csrc/inductor/aoti_torch/shim_mps.mm)
list(APPEND Caffe2_CPU_SRCS ${TORCH_SRC_DIR}/csrc/inductor/aoti_runner/model_container_runner_mps.cpp)
if(CAN_COMPILE_METAL)
file(TOUCH ${CMAKE_BINARY_DIR}/aten/src/ATen/metallib_dummy.cpp)

View File

@ -223,6 +223,20 @@ class MPSBasicTestsAOTI(TestCase):
m = M().to("mps")
self.check_model(m, inp)
def test_two_const(self):
class Model(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.y = torch.ones(3, 3, device="mps")
self.z = torch.full((3, 3), 2, device="mps")
def forward(self, x):
return x + self.y + self.z
inp = (torch.ones(3, 3, device="mps"),)
m = Model().to(device="mps")
self.check_model(m, inp)
if __name__ == "__main__":
from torch._dynamo.test_case import run_tests

View File

@ -1313,6 +1313,9 @@ def get_cpp_torch_device_options(
"in https://github.com/pytorch/pytorch?tab=readme-ov-file#intel-gpu-support."
)
if device_type == "mps":
definitions.append(" USE_MPS")
if config.is_fbcode():
include_dirs.append(build_paths.sdk_include)

View File

@ -15,11 +15,14 @@
// C ABI defined in torch/csrc/inductor/aoti_torch/c/shim.h. The same rule
// applies to other files under torch/csrc/inductor/aoti_runtime/.
#include <torch/csrc/inductor/aoti_runtime/device_utils.h>
#ifdef USE_MPS
#include <torch/csrc/inductor/aoti_torch/c/shim_mps.h>
#endif // USE_MPS
#ifdef USE_XPU
#include <torch/csrc/inductor/aoti_runtime/utils_xpu.h>
#else
#include <torch/csrc/inductor/aoti_runtime/utils.h>
#endif
#endif // USE_XPU
#include <torch/csrc/inductor/aoti_runtime/constant_type.h>
#define AOTI_RUNTIME_CHECK(EXPR, MSG) \
@ -74,6 +77,15 @@ RAIIDataPtr RAII_gpuMalloc(size_t num_bytes) {
return RAIIDataPtr(data_ptr, deleter);
}
#elif defined(USE_MPS)
RAIIDataPtr RAII_gpuMalloc(size_t num_bytes) {
void* data_ptr = nullptr;
aoti_torch_mps_malloc(&data_ptr, num_bytes);
auto deleter = [](void* ptr) { aoti_torch_mps_free(ptr); };
return RAIIDataPtr(data_ptr, deleter);
}
#else
RAIIDataPtr RAII_cpuMalloc(size_t num_bytes) {
@ -113,7 +125,7 @@ inline void parse_device_str(
} else if (sm[1].str() == "xpu") {
device_type = aoti_torch_device_type_xpu();
#endif
#ifdef __APPLE__
#ifdef USE_MPS
} else if (sm[1].str() == "mps") {
device_type = aoti_torch_device_type_mps();
#endif
@ -165,6 +177,11 @@ class AOTInductorModelBase {
aoti_torch_set_current_xpu_device(device_idx_);
}
#endif // USE_XPU
#ifdef USE_MPS
if (device_idx_ == -1) {
device_idx_ = 0;
}
#endif // USE_MPS
}
// NOLINTNEXTLINE(modernize-use-equals-default)
@ -299,7 +316,7 @@ class AOTInductorModelBase {
if (!include_weights) {
return;
}
#if defined(USE_CUDA) || defined(USE_XPU)
#if defined(USE_CUDA) || defined(USE_XPU) || defined(USE_MPS)
constant_blob_ = RAII_gpuMalloc(blob_size);
#else
constant_blob_ = RAII_cpuMalloc(blob_size);
@ -327,7 +344,12 @@ class AOTInductorModelBase {
auto ndim = this->constant_ndim(i);
auto size = this->constant_shape(i);
auto stride = this->constant_stride(i);
#ifdef USE_MPS
auto offset = this->constant_offset(i) +
(constants_internal_offset[i] / aoti_torch_dtype_element_size(dtype));
#else
auto offset = this->constant_offset(i);
#endif
auto layout = this->constant_layout(i);
auto opaque_metadata_ptr = this->opaque_metadata(i);
auto opaque_metadata_size = this->opaque_metadata_size(i);
@ -390,6 +412,14 @@ class AOTInductorModelBase {
_get_constants_start() + bytes_read,
data_size,
cudaMemcpyHostToDevice));
#elif USE_MPS
aoti_torch_mps_memcpy(
constants_ptr,
constant_offset,
bytes_read,
data_size,
_get_constants_start());
return constants_ptr;
#else
memcpy(internal_ptr, _get_constants_start() + bytes_read, data_size);
#endif

View File

@ -666,7 +666,7 @@ class AOTInductorModelContainer {
std::shared_mutex model_exec_mutex_;
RAIIDataPtr allocate_constant_blob() {
#if defined(USE_CUDA) || defined(USE_XPU)
#if defined(USE_CUDA) || defined(USE_XPU) || defined(USE_MPS)
return RAII_gpuMalloc(blob_size_);
#else
return RAII_cpuMalloc(blob_size_);

View File

@ -129,6 +129,7 @@ AOTI_TORCH_EXPORT int32_t aoti_torch_dtype_bool();
AOTI_TORCH_EXPORT int32_t aoti_torch_dtype_complex32();
AOTI_TORCH_EXPORT int32_t aoti_torch_dtype_complex64();
AOTI_TORCH_EXPORT int32_t aoti_torch_dtype_complex128();
AOTI_TORCH_EXPORT size_t aoti_torch_dtype_element_size(int32_t dtype);
AOTI_TORCH_EXPORT int32_t aoti_torch_layout_strided();
AOTI_TORCH_EXPORT int32_t aoti_torch_layout_sparse_coo();

View File

@ -15,6 +15,18 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_set_arg(
unsigned idx,
AtenTensorHandle tensor);
AOTI_TORCH_EXPORT AOTITorchError
aoti_torch_mps_malloc(void** buffer, size_t num_bytes);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_free(void* ptr);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_memcpy(
void* buffer,
size_t constant_offset,
size_t bytes_read,
size_t data_size,
uint8_t* constants_start);
#ifdef __cplusplus
} // extern "C"
#endif

View File

@ -253,6 +253,11 @@ void aoti_torch_grad_mode_set_enabled(bool enabled) {
return c10::GradMode::set_enabled(enabled);
}
size_t aoti_torch_dtype_element_size(int32_t dtype) {
auto scalar_type = static_cast<at::ScalarType>(dtype);
return c10::elementSize(scalar_type);
}
AOTITorchError aoti_torch_delete_tensor_object(AtenTensorHandle tensor) {
AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
at::Tensor* t = tensor_handle_to_tensor_pointer(tensor);

View File

@ -0,0 +1,42 @@
#include <ATen/native/mps/MetalShaderLibrary.h>
#include <torch/csrc/inductor/aoti_torch/c/shim_mps.h>
#include <torch/csrc/inductor/aoti_torch/utils.h>
#include <ATen/mps/MPSAllocatorInterface.h>
#include <ATen/mps/MPSDevice.h>
using namespace torch::aot_inductor;
AOTITorchError aoti_torch_mps_malloc(
void** buffer,
size_t num_bytes) {
if (num_bytes == 0) {
*buffer = nullptr;
return AOTI_TORCH_SUCCESS;
}
AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
id<MTLDevice> device = at::mps::MPSDevice::getInstance()->device();
TORCH_CHECK(device, "Failed to get MPS device");
id<MTLBuffer> metal_buffer = [device newBufferWithLength:num_bytes options:MTLResourceCPUCacheModeWriteCombined | MTLResourceStorageModeShared];
TORCH_CHECK(metal_buffer, "Failed to allocate memory on MPS device");
*buffer = (void*)metal_buffer;
});
}
AOTITorchError aoti_torch_mps_free(
void* ptr) {
AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
auto metal_buffer = (id<MTLBuffer>)ptr;
[metal_buffer release];
});
}
AOTITorchError
aoti_torch_mps_memcpy(void* buffer, size_t constant_offset, size_t bytes_read, size_t data_size, uint8_t* constants_start) {
AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
auto metal_buffer = (id<MTLBuffer>)buffer;
auto buffer_pointer = static_cast<uint8_t*>([metal_buffer contents]);
memcpy(buffer_pointer + constant_offset, constants_start + bytes_read, data_size);
});
}