AOTI MPS Shim Implementation (#163865)

## MPS Shim API

*   Updated MPS shimification API with handles and function declarations:
    *   `AOTIMetalShaderLibraryHandle` and `AOTIMetalKernelFunctionHandle` types
    *   Library management: `aoti_torch_mps_create_shader_library`, `aoti_torch_mps_delete_shader_library`, `aoti_torch_mps_get_kernel_function`
*   Kernel execution: `aoti_torch_mps_run_command_block`, `aoti_torch_mps_start_encoding`, `aoti_torch_mps_dispatch` variants, etc.

## MPS Shader Codegen

*   Modified to generate source constants instead of direct `DynamicMetalShaderLibrary` instantiation:
    *   **Before**: `at::native::mps::DynamicMetalShaderLibrary mps_lib_0(R"MTL(...)MTL");`
    *   **After**: `const char* mps_lib_0_source = R"MTL(...)MTL";`
*   Updated kernel call generation to use shimified functions:
    *   Generates calls to shimified API instead of direct libtorch calls

## Before vs After Comparison

### Section 1: Shader Library
**Before (Direct Library Object)**
```cpp
at::native::mps::DynamicMetalShaderLibrary mps_lib_0(R"MTL(
    ...
)MTL");
```
**After (Source String)**
```cpp
const char* mps_lib_0_source = R"MTL(
    ...
)MTL";
```

### Section 2: Getter Functions & RAII Management

**Before (Direct Library Access)**
```cpp
const std::shared_ptr<at::native::mps::MetalKernelFunction> get_mps_lib_0() {
    static const auto func = mps_lib_0.getKernelFunction("generated_kernel");
    return func;
}

AOTIMetalKernelFunctionHandle get_mps_lib_0_handle() {
    static const auto handle = AOTIMetalKernelFunctionHandle(get_mps_lib_0().get());
    return handle;
}
```

**After (Shim API + RAII Wrapper)**
```cpp
AOTIMetalKernelFunctionHandle get_mps_lib_0_handle() {
    static auto kernel_handle = []() {
        AOTIMetalShaderLibraryHandle lib_handle = nullptr;
        AOTIMetalKernelFunctionHandle kern_handle = nullptr;

        aoti_torch_mps_create_shader_library(mps_lib_0_source, &lib_handle);
        aoti_torch_mps_get_kernel_function(lib_handle, "generated_kernel", &kern_handle);

        // RAII wrapper with custom deleter
        auto lib_deleter = [](AOTIMetalShaderLibraryHandle h) {
            if (h) aoti_torch_mps_delete_shader_library(h);
        };

        using LibDeleter = decltype(lib_deleter);
        using LibPtr = std::unique_ptr<AOTIMetalShaderLibraryOpaque, LibDeleter>;

        // Return pair of kernel handle and library smart pointer for cleanup
        return std::make_pair(kern_handle, LibPtr(lib_handle, lib_deleter));
    }();
    return kernel_handle.first;
}
```

### Section 3: Runtime Execution

**Before (Direct Library Methods)**
```cpp
void AOTInductorModel::run_impl(...) {

    ...

    get_mps_lib_0()->runCommandBlock([&] {
        get_mps_lib_0()->startEncoding();
        aoti_torch_mps_set_arg_tensor(get_mps_lib_0_handle(), 0, buf0);
        aoti_torch_mps_set_arg_tensor(get_mps_lib_0_handle(), 1, arg0_1);
        aoti_torch_mps_set_arg_tensor(get_mps_lib_0_handle(), 2, arg1_1);
        get_mps_lib_0()->dispatch({static_cast<uint64_t>(10LL)});

    });

    ...

} // AOTInductorModel::run_impl
```

**After (Shim API with Lambda Pattern)**
```cpp
void AOTInductorModel::run_impl(...) {

    ...

    auto mps_lib_0_lambda_0 = [&](AOTIMetalKernelFunctionHandle handle) {
        aoti_torch_mps_start_encoding(handle);
        aoti_torch_mps_set_arg_tensor(handle, 0, buf0);
        aoti_torch_mps_set_arg_tensor(handle, 1, arg0_1);
        aoti_torch_mps_set_arg_tensor(handle, 2, arg1_1);
        aoti_torch_mps_dispatch_single(handle, static_cast<uint64_t>(10LL));
    };

    std::function<void(AOTIMetalKernelFunctionHandle)> mps_lib_0_func_wrapper_0 = mps_lib_0_lambda_0;
    aoti_torch_mps_run_command_block(get_mps_lib_0_handle(), aoti_torch_mps_shared_callback, &mps_lib_0_func_wrapper_0);

    ...

} // AOTInductorModel::run_impl
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163865
Approved by: https://github.com/angelayi, https://github.com/desertfire
This commit is contained in:
Manuel Candales 2025-10-09 16:06:36 +00:00 committed by PyTorch MergeBot
parent 3d1fa40ae1
commit aea57b3aa3
9 changed files with 372 additions and 66 deletions

View File

@ -116,6 +116,8 @@ class MetalShaderLibrary {
std::vector<std::string> getFunctionNames(); std::vector<std::string> getFunctionNames();
std::shared_ptr<MetalKernelFunction> getKernelFunction( std::shared_ptr<MetalKernelFunction> getKernelFunction(
const std::string& name); const std::string& name);
// Returns a raw pointer to the kernel function for use in C APIs
MetalKernelFunction* getCachedKernelFunctionPtr(const std::string& name);
inline MTLComputePipelineState_t getPipelineStateForFunc( inline MTLComputePipelineState_t getPipelineStateForFunc(
const std::string& fname) { const std::string& fname) {
return getLibraryPipelineState(getLibrary(), fname).first; return getLibraryPipelineState(getLibrary(), fname).first;
@ -164,6 +166,9 @@ class MetalShaderLibrary {
std::string, std::string,
std::pair<MTLComputePipelineState_t, MTLFunction_t>> std::pair<MTLComputePipelineState_t, MTLFunction_t>>
cplMap; cplMap;
// Cache for kernel functions returned by getCachedKernelFunctionPtr
std::unordered_map<std::string, std::unique_ptr<MetalKernelFunction>>
kernelCache;
}; };
class DynamicMetalShaderLibrary : public MetalShaderLibrary { class DynamicMetalShaderLibrary : public MetalShaderLibrary {

View File

@ -917,6 +917,22 @@ std::shared_ptr<MetalKernelFunction> MetalShaderLibrary::getKernelFunction(const
return std::make_shared<MetalKernelFunction>(cpl, func); return std::make_shared<MetalKernelFunction>(cpl, func);
} }
MetalKernelFunction* MetalShaderLibrary::getCachedKernelFunctionPtr(const std::string& name) {
// Check if kernel is already cached
auto it = kernelCache.find(name);
if (it != kernelCache.end()) {
return it->second.get();
}
// Create new kernel function and cache it
auto [cpl, func] = getLibraryPipelineState(getLibrary(), name);
auto kernel = std::make_unique<MetalKernelFunction>(cpl, func);
MetalKernelFunction* raw_ptr = kernel.get();
kernelCache[name] = std::move(kernel);
return raw_ptr;
}
class BundledShaderLibary : public MetalShaderLibrary { class BundledShaderLibary : public MetalShaderLibrary {
public: public:
BundledShaderLibary() : MetalShaderLibrary("") {} BundledShaderLibary() : MetalShaderLibrary("") {}

View File

@ -202,7 +202,7 @@ class AOTInductorTestsTemplate:
AOTIRunnerUtil.compile, model, example_inputs AOTIRunnerUtil.compile, model, example_inputs
) )
if self.device == "mps": if self.device == "mps":
FileCheck().check("getKernelFunction(").run(code) FileCheck().check("aoti_torch_mps_get_kernel_function(").run(code)
elif self.device == GPU_TYPE: elif self.device == GPU_TYPE:
FileCheck().check("launchKernel(").run(code) FileCheck().check("launchKernel(").run(code)
if config.aot_inductor.embed_kernel_binary: if config.aot_inductor.embed_kernel_binary:
@ -2893,7 +2893,7 @@ class AOTInductorTestsTemplate:
if self.device == "mps": if self.device == "mps":
self.code_check_count( self.code_check_count(
model, example_inputs, '.getKernelFunction("generated_kernel")', 1 model, example_inputs, "aoti_torch_mps_get_kernel_function(", 1
) )
elif self.device == GPU_TYPE: elif self.device == GPU_TYPE:
self.code_check_count( self.code_check_count(

View File

@ -270,7 +270,7 @@ class MPSBasicTestsAOTI(TestCase):
ep = torch.export.export(model, example_inputs) ep = torch.export.export(model, example_inputs)
package_path = torch._export.aot_compile(ep.module(), example_inputs) package_path = torch._export.aot_compile(ep.module(), example_inputs)
target_str = 'mps_lib_0.getKernelFunction("generated_kernel")' target_str = "aoti_torch_mps_get_kernel_function("
target_count = 1 target_count = 1
with open(os.path.splitext(package_path)[0] + ".cpp") as cpp: with open(os.path.splitext(package_path)[0] + ".cpp") as cpp:

View File

@ -20,6 +20,7 @@ class CppWrapperMps(CppWrapperGpu):
def __init__(self) -> None: def __init__(self) -> None:
super().__init__() super().__init__()
self._used_kernel_names: OrderedSet[str] = OrderedSet() self._used_kernel_names: OrderedSet[str] = OrderedSet()
self._lambda_counter: int = 0
@staticmethod @staticmethod
def create( def create(
@ -47,13 +48,16 @@ class CppWrapperMps(CppWrapperGpu):
""" """
Generates MPS kernel call code. It should look something like: Generates MPS kernel call code. It should look something like:
``` ```
get_mps_lib_0()->runCommandBlock([&] { auto mps_lib_0_lambda = [&](AOTIMetalKernelFunctionHandle handle) {
get_mps_lib_0()->startEncoding(); aoti_torch_mps_start_encoding(handle);
aoti_torch_mps_set_arg(get_mps_lib_0_handle(), 0, buf0); aoti_torch_mps_set_arg_tensor(handle, 0, buf0);
aoti_torch_mps_set_arg(get_mps_lib_0_handle(), 1, arg0_1); aoti_torch_mps_set_arg_tensor(handle, 1, arg0_1);
... aoti_torch_mps_set_arg_tensor(handle, 2, arg1_1);
get_mps_lib_0()->dispatch(9); aoti_torch_mps_dispatch_single(handle, static_cast<uint64_t>(10LL));
}); };
std::function<void(AOTIMetalKernelFunctionHandle)> mps_lib_0_func_wrapper = mps_lib_0_lambda;
aoti_torch_mps_run_command_block(get_mps_lib_0_handle(), aoti_torch_mps_shared_callback, &mps_lib_0_func_wrapper);
``` ```
""" """
device = device or V.graph.get_current_device_or_throw() device = device or V.graph.get_current_device_or_throw()
@ -78,13 +82,9 @@ class CppWrapperMps(CppWrapperGpu):
new_args = [] new_args = []
for idx, (arg, arg_type) in enumerate(zip(call_args[:-2], arg_types[:-2])): for idx, (arg, arg_type) in enumerate(zip(call_args[:-2], arg_types[:-2])):
if isinstance(arg_type, torch.dtype): if isinstance(arg_type, torch.dtype):
new_args.append( new_args.append(f"aoti_torch_mps_set_arg_tensor(handle, {idx}, {arg});")
f"aoti_torch_mps_set_arg_tensor(get_{kernel_name}_handle(), {idx}, {arg});"
)
elif arg_type in (int, sympy.core.symbol.Symbol): elif arg_type in (int, sympy.core.symbol.Symbol):
new_args.append( new_args.append(f"aoti_torch_mps_set_arg_int(handle, {idx}, {arg});")
f"aoti_torch_mps_set_arg_int(get_{kernel_name}_handle(), {idx}, {arg});"
)
else: else:
raise NotImplementedError( raise NotImplementedError(
f"Unsupported arg type {arg_type} for arg {arg} for kernel {kernel_name}" f"Unsupported arg type {arg_type} for arg {arg} for kernel {kernel_name}"
@ -93,12 +93,85 @@ class CppWrapperMps(CppWrapperGpu):
threads, group_size = call_args[-2], call_args[-1] threads, group_size = call_args[-2], call_args[-1]
if threads is None: if threads is None:
raise NotImplementedError("No threads or group_size provided") raise NotImplementedError("No threads or group_size provided")
elif group_size is None:
new_args.append(f"get_{kernel_name}()->dispatch({threads});\n") # Check if threads is a single value or an array-like structure
threads_str = str(threads)
is_single_value = (
threads_str.startswith("{")
and threads_str.endswith("}")
and threads_str.count(",") == 0
) or not threads_str.startswith(("{", "["))
if is_single_value:
# Extract single value from braces if present
if threads_str.startswith("{") and threads_str.endswith("}"):
single_value = threads_str[1:-1].strip() # Remove braces
else:
single_value = threads_str
if group_size is None:
new_args.append(
f"aoti_torch_mps_dispatch_single(handle, {single_value});"
)
else:
# Extract group size value if it's also in braces
group_size_str = str(group_size)
if group_size_str.startswith("{") and group_size_str.endswith("}"):
group_size_value = group_size_str[1:-1].strip()
else:
group_size_value = group_size_str
new_args.append(
f"aoti_torch_mps_dispatch_single_with_group_size(handle, {single_value}, {group_size_value});"
)
else: else:
new_args.append( # Handle array case - need to convert initializer list to array
f"get_{kernel_name}()->dispatch({threads}, {group_size});\n" # Use kernel name to make variable names unique
) threads_var = f"{kernel_name}_threads_array"
group_size_var = f"{kernel_name}_group_size_array"
# Extract array size from the initializer list string
def get_array_size(array_str: str) -> int:
# Remove braces and whitespace
content = array_str.strip()
if content.startswith("{") and content.endswith("}"):
content = content[1:-1].strip()
if not content: # Empty array
return 0
# Count elements by counting commas, accounting for nested structures
depth = 0
comma_count = 0
for char in content:
if char in "({[<":
depth += 1
elif char in ")}]>":
depth -= 1
elif char == "," and depth == 0:
comma_count += 1
return comma_count + 1 # Number of elements = commas + 1
threads_size = get_array_size(threads_str)
if group_size is None:
new_args.append("{")
new_args.append(f" uint64_t {threads_var}[] = {threads};")
new_args.append(
f" aoti_torch_mps_dispatch_array(handle, {threads_var}, {threads_size});"
)
new_args.append("}")
else:
group_size_str = str(group_size)
group_size_size = get_array_size(group_size_str)
new_args.append("{")
new_args.append(f" uint64_t {threads_var}[] = {threads};")
new_args.append(f" uint64_t {group_size_var}[] = {group_size};")
dispatch_args = f"handle, {threads_var}, {threads_size}, {group_size_var}, {group_size_size}"
new_args.append(
f" aoti_torch_mps_dispatch_array_with_group_size({dispatch_args});"
)
new_args.append("}")
# debug printer related logic for cpp kernel type. # debug printer related logic for cpp kernel type.
debug_printer_manager = V.graph.wrapper_code.debug_printer debug_printer_manager = V.graph.wrapper_code.debug_printer
@ -113,14 +186,34 @@ class CppWrapperMps(CppWrapperGpu):
self.write_mps_kernel_call(kernel_name, new_args) self.write_mps_kernel_call(kernel_name, new_args)
def write_mps_kernel_call(self, name: str, call_args: list[str]) -> None: def write_mps_kernel_call(self, name: str, call_args: list[str]) -> None:
# Initialization of the kernel function and kernel function handle # Generate unique variable names to avoid duplicate declarations
# variables have already been done at the beginning, which was # when the same MPS lib is used multiple times
# codegen-ed in `codegen_mps_func_init` unique_suffix = self._lambda_counter
self.writeline(f"get_{name}()->runCommandBlock([&] {{") self._lambda_counter += 1
self.writeline(f" get_{name}()->startEncoding();")
lambda_name = f"{name}_lambda_{unique_suffix}"
wrapper_name = f"{name}_func_wrapper_{unique_suffix}"
# Generate the function call code (in current location)
# Create lambda that captures by reference and pass its pointer through void*
self.writeline(
f"auto {lambda_name} = [&](AOTIMetalKernelFunctionHandle handle) {{"
)
self.writeline(" aoti_torch_mps_start_encoding(handle);")
# Output call args directly since we're capturing by reference
for call_arg in call_args: for call_arg in call_args:
self.writeline(f" {call_arg}") self.writeline(f" {call_arg}")
self.writeline("});") self.writeline("};")
self.writeline("")
# Pass lambda pointer through void*
self.writeline(
f"std::function<void(AOTIMetalKernelFunctionHandle)> {wrapper_name} = {lambda_name};"
)
self.writeline(
f"aoti_torch_mps_run_command_block(get_{name}_handle(), aoti_torch_mps_shared_callback, &{wrapper_name});"
)
@staticmethod @staticmethod
def get_device_include_path(device: str) -> str: def get_device_include_path(device: str) -> str:
@ -132,49 +225,77 @@ class CppWrapperMps(CppWrapperGpu):
def codegen_additional_funcs(self) -> None: def codegen_additional_funcs(self) -> None:
""" """
We want to codegen the mps kernel function variable initializations Generate thread-safe lazy singleton pattern for MPS shader libraries with RAII cleanup.
ahead of time. This is so that if we reuse kernels within subgraphs, we
don't need to worry about the scope in which we're initializing the
variables. Instead we will just initialize the variables all at the top
level.
The kernel function variable initializations should look something like: The generated code will look like:
``` ```
const std::shared_ptr<at::native::mps::MetalKernelFunction> get_mps_lib_0() {
static const auto func = mps_lib_0.getKernelFunction("generated_kernel");
return func;
}
AOTIMetalKernelFunctionHandle get_mps_lib_0_handle() { AOTIMetalKernelFunctionHandle get_mps_lib_0_handle() {
static const auto handle = AOTIMetalKernelFunctionHandle(get_mps_lib_0().get()); static auto kernel_handle = []() {
return handle; AOTIMetalShaderLibraryHandle lib_handle = nullptr;
AOTIMetalKernelFunctionHandle kern_handle = nullptr;
aoti_torch_mps_create_shader_library(mps_lib_0_source, &lib_handle);
aoti_torch_mps_get_kernel_function(lib_handle, "generated_kernel", &kern_handle);
// RAII wrapper with custom deleter
auto lib_deleter = [](AOTIMetalShaderLibraryHandle h) {
if (h) aoti_torch_mps_delete_shader_library(h);
};
using LibDeleter = decltype(lib_deleter);
using LibPtr = std::unique_ptr<AOTIMetalShaderLibraryOpaque, LibDeleter>;
// Return pair of kernel handle and library smart pointer for cleanup
return std::make_pair(kern_handle, LibPtr(lib_handle, lib_deleter));
}();
return kernel_handle.first;
} }
``` ```
""" """
# Add shimified handles and functions
shader_libraries: OrderedSet[str] = OrderedSet()
for line in self.lines: for line in self.lines:
if not isinstance(line, KernelCallLine): if not isinstance(line, KernelCallLine):
continue continue
if line.device.type != "mps": if line.device.type != "mps":
continue continue
# Only add handle definition once # Extract library name from kernel name (e.g., "mps_lib_0" from kernel calls)
if line.kernel_name not in self._used_kernel_names: if line.kernel_name not in self._used_kernel_names:
self._used_kernel_names.add(line.kernel_name) self._used_kernel_names.add(line.kernel_name)
shader_libraries.add(line.kernel_name)
self.prefix.writeline( # NOTE: For shimified version, we expect the shader source constant to be generated
f"const std::shared_ptr<at::native::mps::MetalKernelFunction> get_{line.kernel_name}() {{" # by the existing MPS shader generation process, but instead of instantiating the
) # DynamicMetalShaderLibrary directly, we'll use our shim functions.
self.prefix.writeline( # The existing codegen should produce something like:
f' static const auto func = {line.kernel_name}.getKernelFunction("generated_kernel");' # const char* mps_lib_0_source = R"MTL(...shader_source...)MTL";
) # instead of:
self.prefix.writeline(" return func;") # at::native::mps::DynamicMetalShaderLibrary mps_lib_0(R"MTL(...shader_source...)MTL");
self.prefix.writeline("}")
self.prefix.writeline( # Generate thread-safe lazy singleton with RAII for each library
f"AOTIMetalKernelFunctionHandle get_{line.kernel_name}_handle() {{" for lib_name in shader_libraries:
) self.prefix.splice(f"""
self.prefix.writeline( AOTIMetalKernelFunctionHandle get_{lib_name}_handle() {{
f" static const auto handle = AOTIMetalKernelFunctionHandle(get_{line.kernel_name}().get());" static auto kernel_handle = []() {{
) AOTIMetalShaderLibraryHandle lib_handle = nullptr;
self.prefix.writeline(" return handle;") AOTIMetalKernelFunctionHandle kern_handle = nullptr;
self.prefix.writeline("}")
aoti_torch_mps_create_shader_library({lib_name}_source, &lib_handle);
aoti_torch_mps_get_kernel_function(lib_handle, "generated_kernel", &kern_handle);
// RAII wrapper with custom deleter
auto lib_deleter = [](AOTIMetalShaderLibraryHandle h) {{
if (h) aoti_torch_mps_delete_shader_library(h);
}};
using LibDeleter = decltype(lib_deleter);
using LibPtr = std::unique_ptr<AOTIMetalShaderLibraryOpaque, LibDeleter>;
// Return pair of kernel handle and library smart pointer for cleanup
return std::make_pair(kern_handle, LibPtr(lib_handle, lib_deleter));
}}();
return kernel_handle.first;
}}
""")

View File

@ -1058,10 +1058,8 @@ class MetalScheduling(SIMDScheduling):
wrapper.src_to_kernel[src_code] = kernel_name wrapper.src_to_kernel[src_code] = kernel_name
if V.graph.cpp_wrapper: if V.graph.cpp_wrapper:
src_code = ( # For shimified version, generate source constant instead of direct instantiation
f"at::native::mps::DynamicMetalShaderLibrary {mps_lib_name}" src_code = f"const char* {mps_lib_name}_source = " + src_code
+ src_code
)
origins, detailed_origins = get_kernel_metadata(node_schedule, wrapper) origins, detailed_origins = get_kernel_metadata(node_schedule, wrapper)
metadata_comment = f"{origins}\n{detailed_origins}" metadata_comment = f"{origins}\n{detailed_origins}"

View File

@ -3,12 +3,32 @@
#include <torch/csrc/inductor/aoti_torch/c/shim.h> #include <torch/csrc/inductor/aoti_torch/c/shim.h>
struct AOTIMetalKernelFunctionOpaque;
using AOTIMetalKernelFunctionHandle = AOTIMetalKernelFunctionOpaque*;
struct AOTIMetalShaderLibraryOpaque;
using AOTIMetalShaderLibraryHandle = AOTIMetalShaderLibraryOpaque*;
#ifdef __cplusplus #ifdef __cplusplus
extern "C" { extern "C" {
#endif #endif
struct AOTIMetalKernelFunctionOpaque; // MetalShaderLibrary functions
using AOTIMetalKernelFunctionHandle = AOTIMetalKernelFunctionOpaque*; AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_create_shader_library(
const char* metal_shader_source,
AOTIMetalShaderLibraryHandle* library_handle);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_delete_shader_library(
AOTIMetalShaderLibraryHandle library_handle);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_get_kernel_function(
AOTIMetalShaderLibraryHandle library_handle,
const char* kernel_name,
AOTIMetalKernelFunctionHandle* function_handle);
// MetalKernelFunction functions
AOTI_TORCH_EXPORT AOTITorchError
aoti_torch_mps_start_encoding(AOTIMetalKernelFunctionHandle func);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_set_arg_tensor( AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_set_arg_tensor(
AOTIMetalKernelFunctionHandle func, AOTIMetalKernelFunctionHandle func,
@ -20,6 +40,27 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_set_arg_int(
unsigned idx, unsigned idx,
int64_t val); int64_t val);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_dispatch_single(
AOTIMetalKernelFunctionHandle func,
uint64_t length);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_dispatch_single_with_group_size(
AOTIMetalKernelFunctionHandle func,
uint64_t length,
uint64_t group_size);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_dispatch_array(
AOTIMetalKernelFunctionHandle func,
const uint64_t* length,
size_t length_size);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_dispatch_array_with_group_size(
AOTIMetalKernelFunctionHandle func,
const uint64_t* length,
size_t length_size,
const uint64_t* group_size,
size_t group_size_size);
AOTI_TORCH_EXPORT AOTITorchError AOTI_TORCH_EXPORT AOTITorchError
aoti_torch_mps_malloc(void** buffer, size_t num_bytes); aoti_torch_mps_malloc(void** buffer, size_t num_bytes);
@ -39,6 +80,22 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_copy_buffer(
size_t src_offset, size_t src_offset,
size_t dst_offset); size_t dst_offset);
// C callback function type for command block execution
typedef void (*aoti_torch_mps_command_block_callback_t)(
AOTIMetalKernelFunctionHandle func,
void* user_data);
// Shared callback function for std::function trampoline
AOTI_TORCH_EXPORT void aoti_torch_mps_shared_callback(
AOTIMetalKernelFunctionHandle func,
void* user_data);
// Pure C version using function pointer and user data for trampoline pattern
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_run_command_block(
AOTIMetalKernelFunctionHandle func,
aoti_torch_mps_command_block_callback_t callback,
void* user_data);
#ifdef __cplusplus #ifdef __cplusplus
} // extern "C" } // extern "C"
#endif #endif

View File

@ -27,3 +27,116 @@ AOTITorchError aoti_torch_mps_set_arg_int(
func->setArg(idx, val); func->setArg(idx, val);
}); });
} }
AOTITorchError aoti_torch_mps_create_shader_library(
const char* metal_shader_source,
AOTIMetalShaderLibraryHandle* library_handle) {
AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
auto* library = new at::native::mps::DynamicMetalShaderLibrary(
std::string(metal_shader_source));
*library_handle = reinterpret_cast<AOTIMetalShaderLibraryHandle>(library);
});
}
AOTITorchError aoti_torch_mps_delete_shader_library(
AOTIMetalShaderLibraryHandle library_handle) {
AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
auto* library =
reinterpret_cast<at::native::mps::MetalShaderLibrary*>(library_handle);
delete library;
});
}
AOTITorchError aoti_torch_mps_get_kernel_function(
AOTIMetalShaderLibraryHandle library_handle,
const char* kernel_name,
AOTIMetalKernelFunctionHandle* function_handle) {
AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
auto* library =
reinterpret_cast<at::native::mps::MetalShaderLibrary*>(library_handle);
auto* function =
library->getCachedKernelFunctionPtr(std::string(kernel_name));
*function_handle =
reinterpret_cast<AOTIMetalKernelFunctionHandle>(function);
});
}
AOTITorchError aoti_torch_mps_start_encoding(
AOTIMetalKernelFunctionHandle func) {
AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
auto* function_ptr =
reinterpret_cast<at::native::mps::MetalKernelFunction*>(func);
function_ptr->startEncoding();
});
}
AOTITorchError aoti_torch_mps_dispatch_single(
AOTIMetalKernelFunctionHandle func,
uint64_t length) {
AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
auto* function_ptr =
reinterpret_cast<at::native::mps::MetalKernelFunction*>(func);
function_ptr->dispatch(length);
});
}
AOTITorchError aoti_torch_mps_dispatch_single_with_group_size(
AOTIMetalKernelFunctionHandle func,
uint64_t length,
uint64_t group_size) {
AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
auto* function_ptr =
reinterpret_cast<at::native::mps::MetalKernelFunction*>(func);
function_ptr->dispatch(length, group_size);
});
}
AOTITorchError aoti_torch_mps_dispatch_array(
AOTIMetalKernelFunctionHandle func,
const uint64_t* length,
size_t length_size) {
AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
auto* function_ptr =
reinterpret_cast<at::native::mps::MetalKernelFunction*>(func);
c10::ArrayRef<uint64_t> length_ref(length, length_size);
function_ptr->dispatch(length_ref);
});
}
AOTITorchError aoti_torch_mps_dispatch_array_with_group_size(
AOTIMetalKernelFunctionHandle func,
const uint64_t* length,
size_t length_size,
const uint64_t* group_size,
size_t group_size_size) {
AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
auto* function_ptr =
reinterpret_cast<at::native::mps::MetalKernelFunction*>(func);
c10::ArrayRef<uint64_t> length_ref(length, length_size);
c10::ArrayRef<uint64_t> group_size_ref(group_size, group_size_size);
function_ptr->dispatch(length_ref, group_size_ref);
});
}
// Shared callback function for std::function trampoline
void aoti_torch_mps_shared_callback(
AOTIMetalKernelFunctionHandle func,
void* user_data) {
auto* function_wrapper =
static_cast<std::function<void(AOTIMetalKernelFunctionHandle)>*>(
user_data);
(*function_wrapper)(func);
}
// Pure C version using function pointer and user data for trampoline pattern
AOTITorchError aoti_torch_mps_run_command_block(
AOTIMetalKernelFunctionHandle func,
aoti_torch_mps_command_block_callback_t callback,
void* user_data) {
AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
auto* function_ptr =
reinterpret_cast<at::native::mps::MetalKernelFunction*>(func);
function_ptr->runCommandBlock(
[callback, func, user_data]() { callback(func, user_data); });
});
}

View File

@ -1,4 +1,3 @@
#include <ATen/native/mps/MetalShaderLibrary.h>
#include <torch/csrc/inductor/aoti_torch/c/shim_mps.h> #include <torch/csrc/inductor/aoti_torch/c/shim_mps.h>
#include <torch/csrc/inductor/aoti_torch/utils.h> #include <torch/csrc/inductor/aoti_torch/utils.h>
#include <ATen/mps/MPSAllocatorInterface.h> #include <ATen/mps/MPSAllocatorInterface.h>
@ -6,7 +5,6 @@
#include <ATen/mps/MPSStream.h> #include <ATen/mps/MPSStream.h>
#include <ATen/mps/MPSProfiler.h> #include <ATen/mps/MPSProfiler.h>
using namespace torch::aot_inductor; using namespace torch::aot_inductor;
AOTITorchError aoti_torch_mps_malloc( AOTITorchError aoti_torch_mps_malloc(
@ -33,7 +31,6 @@ AOTITorchError aoti_torch_mps_free(
}); });
} }
AOTITorchError AOTITorchError
aoti_torch_mps_memcpy(void* buffer, size_t constant_offset, size_t bytes_read, size_t data_size, uint8_t* constants_start) { aoti_torch_mps_memcpy(void* buffer, size_t constant_offset, size_t bytes_read, size_t data_size, uint8_t* constants_start) {
AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({ AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
@ -46,7 +43,6 @@ aoti_torch_mps_memcpy(void* buffer, size_t constant_offset, size_t bytes_read, s
AOTITorchError AOTITorchError
aoti_torch_mps_copy_buffer(void* src_buffer, void* dst_buffer, size_t data_size, size_t src_offset, size_t dst_offset) { aoti_torch_mps_copy_buffer(void* src_buffer, void* dst_buffer, size_t data_size, size_t src_offset, size_t dst_offset) {
AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({ AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
auto src_mtl_buffer = (id<MTLBuffer>)src_buffer; auto src_mtl_buffer = (id<MTLBuffer>)src_buffer;
auto dst_mtl_buffer = (id<MTLBuffer>)dst_buffer; auto dst_mtl_buffer = (id<MTLBuffer>)dst_buffer;