mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[aoti][mps] mps constants support (#154287)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/154287 Approved by: https://github.com/malfet ghstack dependencies: #155752
This commit is contained in:
parent
8821a9dc4e
commit
a4ab392251
|
|
@ -78,6 +78,9 @@ struct TORCH_API MPSHooksInterface : AcceleratorHooksInterface {
|
||||||
virtual uint32_t acquireEvent(bool enable_timing) const {
|
virtual uint32_t acquireEvent(bool enable_timing) const {
|
||||||
FAIL_MPSHOOKS_FUNC(__func__);
|
FAIL_MPSHOOKS_FUNC(__func__);
|
||||||
}
|
}
|
||||||
|
// Default implementation on the hooks interface: it never returns a device
// and always fails the check below with a message saying the ATen_mps
// library is required.
Device getDeviceFromPtr(void* data) const override {
  TORCH_CHECK(
      false,
      "Cannot get device of pointer on MPS "
      "without ATen_mps library. ");
}
|
||||||
virtual void releaseEvent(uint32_t event_id) const {
|
virtual void releaseEvent(uint32_t event_id) const {
|
||||||
FAIL_MPSHOOKS_FUNC(__func__);
|
FAIL_MPSHOOKS_FUNC(__func__);
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,7 @@
|
||||||
// Copyright © 2022 Apple Inc.
|
// Copyright © 2022 Apple Inc.
|
||||||
|
|
||||||
#pragma once
|
#pragma once
|
||||||
|
#include <ATen/Device.h>
|
||||||
#include <c10/core/Allocator.h>
|
#include <c10/core/Allocator.h>
|
||||||
#include <c10/macros/Macros.h>
|
#include <c10/macros/Macros.h>
|
||||||
#include <c10/util/Exception.h>
|
#include <c10/util/Exception.h>
|
||||||
|
|
@ -70,4 +71,8 @@ TORCH_API bool is_available();
|
||||||
TORCH_API bool is_macos_13_or_newer(MacOSVersion version);
|
TORCH_API bool is_macos_13_or_newer(MacOSVersion version);
|
||||||
TORCH_API at::Allocator* GetMPSAllocator(bool useSharedAllocator = false);
|
TORCH_API at::Allocator* GetMPSAllocator(bool useSharedAllocator = false);
|
||||||
|
|
||||||
|
// Reports the device for a pointer. The pointer value is not inspected:
// the result is always the MPS device type with index 0.
inline Device getDeviceFromPtr(void* ptr) {
  return Device(c10::DeviceType::MPS, 0);
}
|
||||||
|
|
||||||
} // namespace at::mps
|
} // namespace at::mps
|
||||||
|
|
|
||||||
|
|
@ -18,6 +18,8 @@ struct MPSHooks : public at::MPSHooksInterface {
|
||||||
bool hasMPS() const override;
|
bool hasMPS() const override;
|
||||||
bool isOnMacOSorNewer(unsigned major, unsigned minor) const override;
|
bool isOnMacOSorNewer(unsigned major, unsigned minor) const override;
|
||||||
|
|
||||||
|
Device getDeviceFromPtr(void* data) const override;
|
||||||
|
|
||||||
// MPSGeneratorImpl interface
|
// MPSGeneratorImpl interface
|
||||||
const Generator& getDefaultGenerator(
|
const Generator& getDefaultGenerator(
|
||||||
DeviceIndex device_index = -1) const override;
|
DeviceIndex device_index = -1) const override;
|
||||||
|
|
|
||||||
|
|
@ -129,6 +129,10 @@ void MPSHooks::recordEvent(uint32_t event_id) const {
|
||||||
at::mps::getMPSEventPool()->recordEvent(event_id, /* syncEvent*/ true);
|
at::mps::getMPSEventPool()->recordEvent(event_id, /* syncEvent*/ true);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Hook implementation: forwards the query to at::mps::getDeviceFromPtr and
// returns whatever device it reports.
Device MPSHooks::getDeviceFromPtr(void* data) const {
  Device device = at::mps::getDeviceFromPtr(data);
  return device;
}
|
||||||
|
|
||||||
void MPSHooks::waitForEvent(uint32_t event_id) const {
|
void MPSHooks::waitForEvent(uint32_t event_id) const {
|
||||||
at::mps::getMPSEventPool()->waitForEvent(event_id, /* syncEvent*/ true);
|
at::mps::getMPSEventPool()->waitForEvent(event_id, /* syncEvent*/ true);
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -709,6 +709,7 @@ list(APPEND Caffe2_CPU_SRCS ${TORCH_SRCS})
|
||||||
if(USE_MPS)
|
if(USE_MPS)
|
||||||
list(APPEND Caffe2_CPU_SRCS ${Caffe2_MPS_SRCS})
|
list(APPEND Caffe2_CPU_SRCS ${Caffe2_MPS_SRCS})
|
||||||
list(APPEND Caffe2_CPU_SRCS ${TORCH_SRC_DIR}/csrc/inductor/aoti_torch/shim_mps.cpp)
|
list(APPEND Caffe2_CPU_SRCS ${TORCH_SRC_DIR}/csrc/inductor/aoti_torch/shim_mps.cpp)
|
||||||
|
list(APPEND Caffe2_CPU_SRCS ${TORCH_SRC_DIR}/csrc/inductor/aoti_torch/shim_mps.mm)
|
||||||
list(APPEND Caffe2_CPU_SRCS ${TORCH_SRC_DIR}/csrc/inductor/aoti_runner/model_container_runner_mps.cpp)
|
list(APPEND Caffe2_CPU_SRCS ${TORCH_SRC_DIR}/csrc/inductor/aoti_runner/model_container_runner_mps.cpp)
|
||||||
if(CAN_COMPILE_METAL)
|
if(CAN_COMPILE_METAL)
|
||||||
file(TOUCH ${CMAKE_BINARY_DIR}/aten/src/ATen/metallib_dummy.cpp)
|
file(TOUCH ${CMAKE_BINARY_DIR}/aten/src/ATen/metallib_dummy.cpp)
|
||||||
|
|
|
||||||
|
|
@ -223,6 +223,20 @@ class MPSBasicTestsAOTI(TestCase):
|
||||||
m = M().to("mps")
|
m = M().to("mps")
|
||||||
self.check_model(m, inp)
|
self.check_model(m, inp)
|
||||||
|
|
||||||
|
def test_two_const(self):
    """Run check_model on a module that adds two MPS tensor attributes to its input."""

    class Model(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            # Plain tensor attributes created directly on the MPS device.
            self.y = torch.ones(3, 3, device="mps")
            self.z = torch.full((3, 3), 2, device="mps")

        def forward(self, x):
            partial_sum = x + self.y
            return partial_sum + self.z

    example_inputs = (torch.ones(3, 3, device="mps"),)
    model = Model().to(device="mps")
    self.check_model(model, example_inputs)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
from torch._dynamo.test_case import run_tests
|
from torch._dynamo.test_case import run_tests
|
||||||
|
|
|
||||||
|
|
@ -1313,6 +1313,9 @@ def get_cpp_torch_device_options(
|
||||||
"in https://github.com/pytorch/pytorch?tab=readme-ov-file#intel-gpu-support."
|
"in https://github.com/pytorch/pytorch?tab=readme-ov-file#intel-gpu-support."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if device_type == "mps":
|
||||||
|
definitions.append(" USE_MPS")
|
||||||
|
|
||||||
if config.is_fbcode():
|
if config.is_fbcode():
|
||||||
include_dirs.append(build_paths.sdk_include)
|
include_dirs.append(build_paths.sdk_include)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -15,11 +15,14 @@
|
||||||
// C ABI defined in torch/csrc/inductor/aoti_torch/c/shim.h. The same rule
|
// C ABI defined in torch/csrc/inductor/aoti_torch/c/shim.h. The same rule
|
||||||
// applies to other files under torch/csrc/inductor/aoti_runtime/.
|
// applies to other files under torch/csrc/inductor/aoti_runtime/.
|
||||||
#include <torch/csrc/inductor/aoti_runtime/device_utils.h>
|
#include <torch/csrc/inductor/aoti_runtime/device_utils.h>
|
||||||
|
#ifdef USE_MPS
|
||||||
|
#include <torch/csrc/inductor/aoti_torch/c/shim_mps.h>
|
||||||
|
#endif // USE_MPS
|
||||||
#ifdef USE_XPU
|
#ifdef USE_XPU
|
||||||
#include <torch/csrc/inductor/aoti_runtime/utils_xpu.h>
|
#include <torch/csrc/inductor/aoti_runtime/utils_xpu.h>
|
||||||
#else
|
#else
|
||||||
#include <torch/csrc/inductor/aoti_runtime/utils.h>
|
#include <torch/csrc/inductor/aoti_runtime/utils.h>
|
||||||
#endif
|
#endif // USE_XPU
|
||||||
#include <torch/csrc/inductor/aoti_runtime/constant_type.h>
|
#include <torch/csrc/inductor/aoti_runtime/constant_type.h>
|
||||||
|
|
||||||
#define AOTI_RUNTIME_CHECK(EXPR, MSG) \
|
#define AOTI_RUNTIME_CHECK(EXPR, MSG) \
|
||||||
|
|
@ -74,6 +77,15 @@ RAIIDataPtr RAII_gpuMalloc(size_t num_bytes) {
|
||||||
return RAIIDataPtr(data_ptr, deleter);
|
return RAIIDataPtr(data_ptr, deleter);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#elif defined(USE_MPS)
|
||||||
|
|
||||||
|
// MPS flavour of RAII_gpuMalloc: requests `num_bytes` through
// aoti_torch_mps_malloc and wraps the result in an RAIIDataPtr whose deleter
// passes the pointer to aoti_torch_mps_free.
//
// The status returned by the allocation call is checked with
// AOTI_RUNTIME_CHECK so a failed allocation is reported here instead of
// surfacing later as a null constant blob.
RAIIDataPtr RAII_gpuMalloc(size_t num_bytes) {
  void* data_ptr = nullptr;
  AOTI_RUNTIME_CHECK(
      aoti_torch_mps_malloc(&data_ptr, num_bytes) == AOTI_TORCH_SUCCESS,
      "aoti_torch_mps_malloc failed");
  // The deleter has no way to report an error, so the status returned by
  // aoti_torch_mps_free is intentionally not inspected.
  auto deleter = [](void* ptr) { aoti_torch_mps_free(ptr); };
  return RAIIDataPtr(data_ptr, deleter);
}
|
||||||
|
|
||||||
#else
|
#else
|
||||||
|
|
||||||
RAIIDataPtr RAII_cpuMalloc(size_t num_bytes) {
|
RAIIDataPtr RAII_cpuMalloc(size_t num_bytes) {
|
||||||
|
|
@ -113,7 +125,7 @@ inline void parse_device_str(
|
||||||
} else if (sm[1].str() == "xpu") {
|
} else if (sm[1].str() == "xpu") {
|
||||||
device_type = aoti_torch_device_type_xpu();
|
device_type = aoti_torch_device_type_xpu();
|
||||||
#endif
|
#endif
|
||||||
#ifdef __APPLE__
|
#ifdef USE_MPS
|
||||||
} else if (sm[1].str() == "mps") {
|
} else if (sm[1].str() == "mps") {
|
||||||
device_type = aoti_torch_device_type_mps();
|
device_type = aoti_torch_device_type_mps();
|
||||||
#endif
|
#endif
|
||||||
|
|
@ -165,6 +177,11 @@ class AOTInductorModelBase {
|
||||||
aoti_torch_set_current_xpu_device(device_idx_);
|
aoti_torch_set_current_xpu_device(device_idx_);
|
||||||
}
|
}
|
||||||
#endif // USE_XPU
|
#endif // USE_XPU
|
||||||
|
#ifdef USE_MPS
|
||||||
|
if (device_idx_ == -1) {
|
||||||
|
device_idx_ = 0;
|
||||||
|
}
|
||||||
|
#endif // USE_MPS
|
||||||
}
|
}
|
||||||
|
|
||||||
// NOLINTNEXTLINE(modernize-use-equals-default)
|
// NOLINTNEXTLINE(modernize-use-equals-default)
|
||||||
|
|
@ -299,7 +316,7 @@ class AOTInductorModelBase {
|
||||||
if (!include_weights) {
|
if (!include_weights) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
#if defined(USE_CUDA) || defined(USE_XPU)
|
#if defined(USE_CUDA) || defined(USE_XPU) || defined(USE_MPS)
|
||||||
constant_blob_ = RAII_gpuMalloc(blob_size);
|
constant_blob_ = RAII_gpuMalloc(blob_size);
|
||||||
#else
|
#else
|
||||||
constant_blob_ = RAII_cpuMalloc(blob_size);
|
constant_blob_ = RAII_cpuMalloc(blob_size);
|
||||||
|
|
@ -327,7 +344,12 @@ class AOTInductorModelBase {
|
||||||
auto ndim = this->constant_ndim(i);
|
auto ndim = this->constant_ndim(i);
|
||||||
auto size = this->constant_shape(i);
|
auto size = this->constant_shape(i);
|
||||||
auto stride = this->constant_stride(i);
|
auto stride = this->constant_stride(i);
|
||||||
|
#ifdef USE_MPS
|
||||||
|
auto offset = this->constant_offset(i) +
|
||||||
|
(constants_internal_offset[i] / aoti_torch_dtype_element_size(dtype));
|
||||||
|
#else
|
||||||
auto offset = this->constant_offset(i);
|
auto offset = this->constant_offset(i);
|
||||||
|
#endif
|
||||||
auto layout = this->constant_layout(i);
|
auto layout = this->constant_layout(i);
|
||||||
auto opaque_metadata_ptr = this->opaque_metadata(i);
|
auto opaque_metadata_ptr = this->opaque_metadata(i);
|
||||||
auto opaque_metadata_size = this->opaque_metadata_size(i);
|
auto opaque_metadata_size = this->opaque_metadata_size(i);
|
||||||
|
|
@ -390,6 +412,14 @@ class AOTInductorModelBase {
|
||||||
_get_constants_start() + bytes_read,
|
_get_constants_start() + bytes_read,
|
||||||
data_size,
|
data_size,
|
||||||
cudaMemcpyHostToDevice));
|
cudaMemcpyHostToDevice));
|
||||||
|
#elif USE_MPS
|
||||||
|
aoti_torch_mps_memcpy(
|
||||||
|
constants_ptr,
|
||||||
|
constant_offset,
|
||||||
|
bytes_read,
|
||||||
|
data_size,
|
||||||
|
_get_constants_start());
|
||||||
|
return constants_ptr;
|
||||||
#else
|
#else
|
||||||
memcpy(internal_ptr, _get_constants_start() + bytes_read, data_size);
|
memcpy(internal_ptr, _get_constants_start() + bytes_read, data_size);
|
||||||
#endif
|
#endif
|
||||||
|
|
|
||||||
|
|
@ -666,7 +666,7 @@ class AOTInductorModelContainer {
|
||||||
std::shared_mutex model_exec_mutex_;
|
std::shared_mutex model_exec_mutex_;
|
||||||
|
|
||||||
RAIIDataPtr allocate_constant_blob() {
|
RAIIDataPtr allocate_constant_blob() {
|
||||||
#if defined(USE_CUDA) || defined(USE_XPU)
|
#if defined(USE_CUDA) || defined(USE_XPU) || defined(USE_MPS)
|
||||||
return RAII_gpuMalloc(blob_size_);
|
return RAII_gpuMalloc(blob_size_);
|
||||||
#else
|
#else
|
||||||
return RAII_cpuMalloc(blob_size_);
|
return RAII_cpuMalloc(blob_size_);
|
||||||
|
|
|
||||||
|
|
@ -129,6 +129,7 @@ AOTI_TORCH_EXPORT int32_t aoti_torch_dtype_bool();
|
||||||
AOTI_TORCH_EXPORT int32_t aoti_torch_dtype_complex32();
|
AOTI_TORCH_EXPORT int32_t aoti_torch_dtype_complex32();
|
||||||
AOTI_TORCH_EXPORT int32_t aoti_torch_dtype_complex64();
|
AOTI_TORCH_EXPORT int32_t aoti_torch_dtype_complex64();
|
||||||
AOTI_TORCH_EXPORT int32_t aoti_torch_dtype_complex128();
|
AOTI_TORCH_EXPORT int32_t aoti_torch_dtype_complex128();
|
||||||
|
AOTI_TORCH_EXPORT size_t aoti_torch_dtype_element_size(int32_t dtype);
|
||||||
|
|
||||||
AOTI_TORCH_EXPORT int32_t aoti_torch_layout_strided();
|
AOTI_TORCH_EXPORT int32_t aoti_torch_layout_strided();
|
||||||
AOTI_TORCH_EXPORT int32_t aoti_torch_layout_sparse_coo();
|
AOTI_TORCH_EXPORT int32_t aoti_torch_layout_sparse_coo();
|
||||||
|
|
|
||||||
|
|
@ -15,6 +15,18 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_set_arg(
|
||||||
unsigned idx,
|
unsigned idx,
|
||||||
AtenTensorHandle tensor);
|
AtenTensorHandle tensor);
|
||||||
|
|
||||||
|
AOTI_TORCH_EXPORT AOTITorchError
|
||||||
|
aoti_torch_mps_malloc(void** buffer, size_t num_bytes);
|
||||||
|
|
||||||
|
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_free(void* ptr);
|
||||||
|
|
||||||
|
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_memcpy(
|
||||||
|
void* buffer,
|
||||||
|
size_t constant_offset,
|
||||||
|
size_t bytes_read,
|
||||||
|
size_t data_size,
|
||||||
|
uint8_t* constants_start);
|
||||||
|
|
||||||
#ifdef __cplusplus
|
#ifdef __cplusplus
|
||||||
} // extern "C"
|
} // extern "C"
|
||||||
#endif
|
#endif
|
||||||
|
|
|
||||||
|
|
@ -253,6 +253,11 @@ void aoti_torch_grad_mode_set_enabled(bool enabled) {
|
||||||
return c10::GradMode::set_enabled(enabled);
|
return c10::GradMode::set_enabled(enabled);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
size_t aoti_torch_dtype_element_size(int32_t dtype) {
|
||||||
|
auto scalar_type = static_cast<at::ScalarType>(dtype);
|
||||||
|
return c10::elementSize(scalar_type);
|
||||||
|
}
|
||||||
|
|
||||||
AOTITorchError aoti_torch_delete_tensor_object(AtenTensorHandle tensor) {
|
AOTITorchError aoti_torch_delete_tensor_object(AtenTensorHandle tensor) {
|
||||||
AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
|
AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
|
||||||
at::Tensor* t = tensor_handle_to_tensor_pointer(tensor);
|
at::Tensor* t = tensor_handle_to_tensor_pointer(tensor);
|
||||||
|
|
|
||||||
42
torch/csrc/inductor/aoti_torch/shim_mps.mm
Normal file
42
torch/csrc/inductor/aoti_torch/shim_mps.mm
Normal file
|
|
@ -0,0 +1,42 @@
|
||||||
|
#include <ATen/native/mps/MetalShaderLibrary.h>
|
||||||
|
#include <torch/csrc/inductor/aoti_torch/c/shim_mps.h>
|
||||||
|
#include <torch/csrc/inductor/aoti_torch/utils.h>
|
||||||
|
#include <ATen/mps/MPSAllocatorInterface.h>
|
||||||
|
#include <ATen/mps/MPSDevice.h>
|
||||||
|
|
||||||
|
|
||||||
|
using namespace torch::aot_inductor;
|
||||||
|
|
||||||
|
// Allocates an MTLBuffer of `num_bytes` bytes on the MPS device and stores it
// in `*buffer` as an opaque pointer.
//
// - A zero-byte request stores nullptr and succeeds without creating a buffer.
// - A null `buffer` out-parameter is rejected. The check runs inside the
//   exception-to-error-code wrapper, so it is reported through the return
//   value instead of dereferencing a null pointer.
// - The buffer is requested with shared storage and write-combined CPU cache
//   mode.
//
// The stored pointer is the one aoti_torch_mps_free expects to receive.
AOTITorchError aoti_torch_mps_malloc(void** buffer, size_t num_bytes) {
  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
    TORCH_CHECK(
        buffer != nullptr,
        "aoti_torch_mps_malloc: output buffer pointer must not be null");
    if (num_bytes == 0) {
      *buffer = nullptr;
    } else {
      id<MTLDevice> device = at::mps::MPSDevice::getInstance()->device();
      TORCH_CHECK(device, "Failed to get MPS device");
      id<MTLBuffer> metal_buffer = [device
          newBufferWithLength:num_bytes
                      options:MTLResourceCPUCacheModeWriteCombined |
                      MTLResourceStorageModeShared];
      TORCH_CHECK(metal_buffer, "Failed to allocate memory on MPS device");
      *buffer = (void*)metal_buffer;
    }
  });
}
|
||||||
|
|
||||||
|
// Treats `ptr` as an MTLBuffer handle and sends it `release`. Any exception
// raised inside the wrapper is converted to an error code.
AOTITorchError aoti_torch_mps_free(void* ptr) {
  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
    [(id<MTLBuffer>)ptr release];
  });
}
|
||||||
|
|
||||||
|
|
||||||
|
// Copies `data_size` bytes from `constants_start + bytes_read` into the
// MTLBuffer behind `buffer`, starting `constant_offset` bytes into the
// buffer's CPU-visible contents.
//
// A zero-length copy is a no-op. Otherwise the arguments are validated before
// the copy, so a bad call is reported as an error code instead of writing
// through a null pointer or past the end of the buffer:
//   - `buffer` and `constants_start` must be non-null;
//   - the destination range must lie within the buffer's length;
//   - the buffer must expose a non-null contents pointer.
AOTITorchError aoti_torch_mps_memcpy(
    void* buffer,
    size_t constant_offset,
    size_t bytes_read,
    size_t data_size,
    uint8_t* constants_start) {
  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
    if (data_size > 0) {
      TORCH_CHECK(
          buffer != nullptr,
          "aoti_torch_mps_memcpy: destination buffer must not be null");
      TORCH_CHECK(
          constants_start != nullptr,
          "aoti_torch_mps_memcpy: constants_start must not be null");
      auto metal_buffer = (id<MTLBuffer>)buffer;
      const size_t capacity = [metal_buffer length];
      // Written as two comparisons so the bound test cannot overflow.
      TORCH_CHECK(
          constant_offset <= capacity &&
              data_size <= capacity - constant_offset,
          "aoti_torch_mps_memcpy: copy range exceeds the MPS buffer length");
      auto buffer_pointer = static_cast<uint8_t*>([metal_buffer contents]);
      TORCH_CHECK(
          buffer_pointer != nullptr,
          "aoti_torch_mps_memcpy: MPS buffer has no CPU-accessible contents");
      memcpy(
          buffer_pointer + constant_offset,
          constants_start + bytes_read,
          data_size);
    }
  });
}
|
||||||
Loading…
Reference in New Issue
Block a user