AOTI MPS Shim Implementation (#163865)
## MPS Shim API
* Updated the MPS shimification API with opaque handle types and new function declarations (sketched below):
  * `AOTIMetalShaderLibraryHandle` and `AOTIMetalKernelFunctionHandle` handle types
  * Library management: `aoti_torch_mps_create_shader_library`, `aoti_torch_mps_delete_shader_library`, `aoti_torch_mps_get_kernel_function`
  * Kernel execution: `aoti_torch_mps_run_command_block`, `aoti_torch_mps_start_encoding`, the `aoti_torch_mps_dispatch_*` variants, etc.
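For reference, the shim surface added by this PR looks like the following (condensed from the new declarations in `shim_mps.h`; the full set, including the `*_with_group_size` dispatch variants, is in the diff below):
```cpp
struct AOTIMetalShaderLibraryOpaque;
using AOTIMetalShaderLibraryHandle = AOTIMetalShaderLibraryOpaque*;
struct AOTIMetalKernelFunctionOpaque;
using AOTIMetalKernelFunctionHandle = AOTIMetalKernelFunctionOpaque*;

// Library management
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_create_shader_library(
    const char* metal_shader_source,
    AOTIMetalShaderLibraryHandle* library_handle);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_delete_shader_library(
    AOTIMetalShaderLibraryHandle library_handle);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_get_kernel_function(
    AOTIMetalShaderLibraryHandle library_handle,
    const char* kernel_name,
    AOTIMetalKernelFunctionHandle* function_handle);

// Kernel execution
AOTI_TORCH_EXPORT AOTITorchError
aoti_torch_mps_start_encoding(AOTIMetalKernelFunctionHandle func);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_dispatch_single(
    AOTIMetalKernelFunctionHandle func, uint64_t length);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_dispatch_array(
    AOTIMetalKernelFunctionHandle func, const uint64_t* length, size_t length_size);

// C callback type used by aoti_torch_mps_run_command_block
typedef void (*aoti_torch_mps_command_block_callback_t)(
    AOTIMetalKernelFunctionHandle func,
    void* user_data);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_run_command_block(
    AOTIMetalKernelFunctionHandle func,
    aoti_torch_mps_command_block_callback_t callback,
    void* user_data);
```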
## MPS Shader Codegen
* Modified shader codegen to emit a source-string constant instead of instantiating `DynamicMetalShaderLibrary` directly:
  * **Before**: `at::native::mps::DynamicMetalShaderLibrary mps_lib_0(R"MTL(...)MTL");`
  * **After**: `const char* mps_lib_0_source = R"MTL(...)MTL";`
* Updated kernel call generation to emit calls to the shimified API instead of direct libtorch calls
## Before vs After Comparison
### Section 1: Shader Library
**Before (Direct Library Object)**
```cpp
at::native::mps::DynamicMetalShaderLibrary mps_lib_0(R"MTL(
...
)MTL");
```
**After (Source String)**
```cpp
const char* mps_lib_0_source = (R"MTL(
...
)MTL");
```
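Emitting the shader source as a plain `const char*` constant means the generated model code no longer constructs an `at::native::mps` library object itself; the library is created at run time through the C shim calls, so the generated code stays on the shim surface rather than linking against libtorch C++ classes directly.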
### Section 2: Getter Functions & RAII Management
**Before (Direct Library Access)**
```cpp
const std::shared_ptr<at::native::mps::MetalKernelFunction> get_mps_lib_0() {
    static const auto func = mps_lib_0.getKernelFunction("generated_kernel");
    return func;
}

AOTIMetalKernelFunctionHandle get_mps_lib_0_handle() {
    static const auto handle = AOTIMetalKernelFunctionHandle(get_mps_lib_0().get());
    return handle;
}
```
**After (Shim API + RAII Wrapper)**
```cpp
AOTIMetalKernelFunctionHandle get_mps_lib_0_handle() {
    static auto kernel_handle = []() {
        AOTIMetalShaderLibraryHandle lib_handle = nullptr;
        AOTIMetalKernelFunctionHandle kern_handle = nullptr;

        aoti_torch_mps_create_shader_library(mps_lib_0_source, &lib_handle);
        aoti_torch_mps_get_kernel_function(lib_handle, "generated_kernel", &kern_handle);

        // RAII wrapper with custom deleter
        auto lib_deleter = [](AOTIMetalShaderLibraryHandle h) {
            if (h) aoti_torch_mps_delete_shader_library(h);
        };
        using LibDeleter = decltype(lib_deleter);
        using LibPtr = std::unique_ptr<AOTIMetalShaderLibraryOpaque, LibDeleter>;

        // Return pair of kernel handle and library smart pointer for cleanup
        return std::make_pair(kern_handle, LibPtr(lib_handle, lib_deleter));
    }();
    return kernel_handle.first;
}
```
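A note on the pattern above: because `kernel_handle` is a function-local static initialized by an immediately-invoked lambda, the shader library is compiled and the kernel function looked up exactly once, with thread-safe initialization guaranteed by C++ static-initialization rules. The `std::unique_ptr` half of the returned pair keeps the library handle alive and releases it through `aoti_torch_mps_delete_shader_library` at static destruction, so the cached kernel handle stays valid for the lifetime of the model.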
### Section 3: Runtime Execution
**Before (Direct Library Methods)**
```cpp
void AOTInductorModel::run_impl(...) {
    ...
    get_mps_lib_0()->runCommandBlock([&] {
        get_mps_lib_0()->startEncoding();
        aoti_torch_mps_set_arg_tensor(get_mps_lib_0_handle(), 0, buf0);
        aoti_torch_mps_set_arg_tensor(get_mps_lib_0_handle(), 1, arg0_1);
        aoti_torch_mps_set_arg_tensor(get_mps_lib_0_handle(), 2, arg1_1);
        get_mps_lib_0()->dispatch({static_cast<uint64_t>(10LL)});
    });
    ...
} // AOTInductorModel::run_impl
```
**After (Shim API with Lambda Pattern)**
```cpp
void AOTInductorModel::run_impl(...) {
    ...
    auto mps_lib_0_lambda_0 = [&](AOTIMetalKernelFunctionHandle handle) {
        aoti_torch_mps_start_encoding(handle);
        aoti_torch_mps_set_arg_tensor(handle, 0, buf0);
        aoti_torch_mps_set_arg_tensor(handle, 1, arg0_1);
        aoti_torch_mps_set_arg_tensor(handle, 2, arg1_1);
        aoti_torch_mps_dispatch_single(handle, static_cast<uint64_t>(10LL));
    };
    std::function<void(AOTIMetalKernelFunctionHandle)> mps_lib_0_func_wrapper_0 = mps_lib_0_lambda_0;
    aoti_torch_mps_run_command_block(get_mps_lib_0_handle(), aoti_torch_mps_shared_callback, &mps_lib_0_func_wrapper_0);
    ...
} // AOTInductorModel::run_impl
```
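The `aoti_torch_mps_shared_callback` used above is a small trampoline: the generated code hands a `std::function` to the shim through the `void* user_data` parameter, and the shim invokes it from inside `runCommandBlock`. From the shim implementation added in this PR:
```cpp
// Shared callback function for std::function trampoline
void aoti_torch_mps_shared_callback(
    AOTIMetalKernelFunctionHandle func,
    void* user_data) {
  auto* function_wrapper =
      static_cast<std::function<void(AOTIMetalKernelFunctionHandle)>*>(user_data);
  (*function_wrapper)(func);
}

// Pure C version using function pointer and user data for trampoline pattern
AOTITorchError aoti_torch_mps_run_command_block(
    AOTIMetalKernelFunctionHandle func,
    aoti_torch_mps_command_block_callback_t callback,
    void* user_data) {
  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
    auto* function_ptr =
        reinterpret_cast<at::native::mps::MetalKernelFunction*>(func);
    function_ptr->runCommandBlock(
        [callback, func, user_data]() { callback(func, user_data); });
  });
}
```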
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163865
Approved by: https://github.com/angelayi, https://github.com/desertfire
Commit `aea57b3aa3` (parent `3d1fa40ae1`)
## Diff

```diff
@@ -116,6 +116,8 @@ class MetalShaderLibrary {
   std::vector<std::string> getFunctionNames();
   std::shared_ptr<MetalKernelFunction> getKernelFunction(
       const std::string& name);
+  // Returns a raw pointer to the kernel function for use in C APIs
+  MetalKernelFunction* getCachedKernelFunctionPtr(const std::string& name);
   inline MTLComputePipelineState_t getPipelineStateForFunc(
       const std::string& fname) {
     return getLibraryPipelineState(getLibrary(), fname).first;
@@ -164,6 +166,9 @@ class MetalShaderLibrary {
       std::string,
       std::pair<MTLComputePipelineState_t, MTLFunction_t>>
       cplMap;
+  // Cache for kernel functions returned by getCachedKernelFunctionPtr
+  std::unordered_map<std::string, std::unique_ptr<MetalKernelFunction>>
+      kernelCache;
 };
 
 class DynamicMetalShaderLibrary : public MetalShaderLibrary {
```
```diff
@@ -917,6 +917,22 @@ std::shared_ptr<MetalKernelFunction> MetalShaderLibrary::getKernelFunction(const
   return std::make_shared<MetalKernelFunction>(cpl, func);
 }
 
+MetalKernelFunction* MetalShaderLibrary::getCachedKernelFunctionPtr(const std::string& name) {
+  // Check if kernel is already cached
+  auto it = kernelCache.find(name);
+  if (it != kernelCache.end()) {
+    return it->second.get();
+  }
+
+  // Create new kernel function and cache it
+  auto [cpl, func] = getLibraryPipelineState(getLibrary(), name);
+  auto kernel = std::make_unique<MetalKernelFunction>(cpl, func);
+  MetalKernelFunction* raw_ptr = kernel.get();
+  kernelCache[name] = std::move(kernel);
+
+  return raw_ptr;
+}
+
 class BundledShaderLibary : public MetalShaderLibrary {
  public:
   BundledShaderLibary() : MetalShaderLibrary("") {}
```
```diff
@@ -202,7 +202,7 @@ class AOTInductorTestsTemplate:
             AOTIRunnerUtil.compile, model, example_inputs
         )
         if self.device == "mps":
-            FileCheck().check("getKernelFunction(").run(code)
+            FileCheck().check("aoti_torch_mps_get_kernel_function(").run(code)
         elif self.device == GPU_TYPE:
             FileCheck().check("launchKernel(").run(code)
             if config.aot_inductor.embed_kernel_binary:
@@ -2893,7 +2893,7 @@ class AOTInductorTestsTemplate:
 
         if self.device == "mps":
             self.code_check_count(
-                model, example_inputs, '.getKernelFunction("generated_kernel")', 1
+                model, example_inputs, "aoti_torch_mps_get_kernel_function(", 1
             )
         elif self.device == GPU_TYPE:
             self.code_check_count(
```
```diff
@@ -270,7 +270,7 @@ class MPSBasicTestsAOTI(TestCase):
         ep = torch.export.export(model, example_inputs)
         package_path = torch._export.aot_compile(ep.module(), example_inputs)
 
-        target_str = 'mps_lib_0.getKernelFunction("generated_kernel")'
+        target_str = "aoti_torch_mps_get_kernel_function("
         target_count = 1
 
         with open(os.path.splitext(package_path)[0] + ".cpp") as cpp:
```
```diff
@@ -20,6 +20,7 @@ class CppWrapperMps(CppWrapperGpu):
     def __init__(self) -> None:
         super().__init__()
         self._used_kernel_names: OrderedSet[str] = OrderedSet()
+        self._lambda_counter: int = 0
 
     @staticmethod
     def create(
```
````diff
@@ -47,13 +48,16 @@ class CppWrapperMps(CppWrapperGpu):
         """
         Generates MPS kernel call code. It should look something like:
         ```
-        get_mps_lib_0()->runCommandBlock([&] {
-            get_mps_lib_0()->startEncoding();
-            aoti_torch_mps_set_arg(get_mps_lib_0_handle(), 0, buf0);
-            aoti_torch_mps_set_arg(get_mps_lib_0_handle(), 1, arg0_1);
-            ...
-            get_mps_lib_0()->dispatch(9);
-        });
+        auto mps_lib_0_lambda = [&](AOTIMetalKernelFunctionHandle handle) {
+            aoti_torch_mps_start_encoding(handle);
+            aoti_torch_mps_set_arg_tensor(handle, 0, buf0);
+            aoti_torch_mps_set_arg_tensor(handle, 1, arg0_1);
+            aoti_torch_mps_set_arg_tensor(handle, 2, arg1_1);
+            aoti_torch_mps_dispatch_single(handle, static_cast<uint64_t>(10LL));
+        };
+
+        std::function<void(AOTIMetalKernelFunctionHandle)> mps_lib_0_func_wrapper = mps_lib_0_lambda;
+        aoti_torch_mps_run_command_block(get_mps_lib_0_handle(), aoti_torch_mps_shared_callback, &mps_lib_0_func_wrapper);
         ```
         """
         device = device or V.graph.get_current_device_or_throw()
````
```diff
@@ -78,13 +82,9 @@ class CppWrapperMps(CppWrapperGpu):
         new_args = []
         for idx, (arg, arg_type) in enumerate(zip(call_args[:-2], arg_types[:-2])):
             if isinstance(arg_type, torch.dtype):
-                new_args.append(
-                    f"aoti_torch_mps_set_arg_tensor(get_{kernel_name}_handle(), {idx}, {arg});"
-                )
+                new_args.append(f"aoti_torch_mps_set_arg_tensor(handle, {idx}, {arg});")
             elif arg_type in (int, sympy.core.symbol.Symbol):
-                new_args.append(
-                    f"aoti_torch_mps_set_arg_int(get_{kernel_name}_handle(), {idx}, {arg});"
-                )
+                new_args.append(f"aoti_torch_mps_set_arg_int(handle, {idx}, {arg});")
             else:
                 raise NotImplementedError(
                     f"Unsupported arg type {arg_type} for arg {arg} for kernel {kernel_name}"
```
```diff
@@ -93,12 +93,85 @@ class CppWrapperMps(CppWrapperGpu):
         threads, group_size = call_args[-2], call_args[-1]
         if threads is None:
             raise NotImplementedError("No threads or group_size provided")
-        elif group_size is None:
-            new_args.append(f"get_{kernel_name}()->dispatch({threads});\n")
+
+        # Check if threads is a single value or an array-like structure
+        threads_str = str(threads)
+        is_single_value = (
+            threads_str.startswith("{")
+            and threads_str.endswith("}")
+            and threads_str.count(",") == 0
+        ) or not threads_str.startswith(("{", "["))
+
+        if is_single_value:
+            # Extract single value from braces if present
+            if threads_str.startswith("{") and threads_str.endswith("}"):
+                single_value = threads_str[1:-1].strip()  # Remove braces
+            else:
+                single_value = threads_str
+
+            if group_size is None:
+                new_args.append(
+                    f"aoti_torch_mps_dispatch_single(handle, {single_value});"
+                )
+            else:
+                # Extract group size value if it's also in braces
+                group_size_str = str(group_size)
+                if group_size_str.startswith("{") and group_size_str.endswith("}"):
+                    group_size_value = group_size_str[1:-1].strip()
+                else:
+                    group_size_value = group_size_str
+                new_args.append(
+                    f"aoti_torch_mps_dispatch_single_with_group_size(handle, {single_value}, {group_size_value});"
+                )
         else:
-            new_args.append(
-                f"get_{kernel_name}()->dispatch({threads}, {group_size});\n"
-            )
+            # Handle array case - need to convert initializer list to array
+            # Use kernel name to make variable names unique
+            threads_var = f"{kernel_name}_threads_array"
+            group_size_var = f"{kernel_name}_group_size_array"
+
+            # Extract array size from the initializer list string
+            def get_array_size(array_str: str) -> int:
+                # Remove braces and whitespace
+                content = array_str.strip()
+                if content.startswith("{") and content.endswith("}"):
+                    content = content[1:-1].strip()
+
+                if not content:  # Empty array
+                    return 0
+
+                # Count elements by counting commas, accounting for nested structures
+                depth = 0
+                comma_count = 0
+                for char in content:
+                    if char in "({[<":
+                        depth += 1
+                    elif char in ")}]>":
+                        depth -= 1
+                    elif char == "," and depth == 0:
+                        comma_count += 1
+
+                return comma_count + 1  # Number of elements = commas + 1
+
+            threads_size = get_array_size(threads_str)
+
+            if group_size is None:
+                new_args.append("{")
+                new_args.append(f" uint64_t {threads_var}[] = {threads};")
+                new_args.append(
+                    f" aoti_torch_mps_dispatch_array(handle, {threads_var}, {threads_size});"
+                )
+                new_args.append("}")
+            else:
+                group_size_str = str(group_size)
+                group_size_size = get_array_size(group_size_str)
+                new_args.append("{")
+                new_args.append(f" uint64_t {threads_var}[] = {threads};")
+                new_args.append(f" uint64_t {group_size_var}[] = {group_size};")
+                dispatch_args = f"handle, {threads_var}, {threads_size}, {group_size_var}, {group_size_size}"
+                new_args.append(
+                    f" aoti_torch_mps_dispatch_array_with_group_size({dispatch_args});"
+                )
+                new_args.append("}")
 
         # debug printer related logic for cpp kernel type.
         debug_printer_manager = V.graph.wrapper_code.debug_printer
```
```diff
@@ -113,14 +186,34 @@ class CppWrapperMps(CppWrapperGpu):
         self.write_mps_kernel_call(kernel_name, new_args)
 
     def write_mps_kernel_call(self, name: str, call_args: list[str]) -> None:
-        # Initialization of the kernel function and kernel function handle
-        # variables have already been done at the beginning, which was
-        # codegen-ed in `codegen_mps_func_init`
-        self.writeline(f"get_{name}()->runCommandBlock([&] {{")
-        self.writeline(f" get_{name}()->startEncoding();")
+        # Generate unique variable names to avoid duplicate declarations
+        # when the same MPS lib is used multiple times
+        unique_suffix = self._lambda_counter
+        self._lambda_counter += 1
+
+        lambda_name = f"{name}_lambda_{unique_suffix}"
+        wrapper_name = f"{name}_func_wrapper_{unique_suffix}"
+
+        # Generate the function call code (in current location)
+        # Create lambda that captures by reference and pass its pointer through void*
+        self.writeline(
+            f"auto {lambda_name} = [&](AOTIMetalKernelFunctionHandle handle) {{"
+        )
+        self.writeline(" aoti_torch_mps_start_encoding(handle);")
+
+        # Output call args directly since we're capturing by reference
         for call_arg in call_args:
             self.writeline(f" {call_arg}")
-        self.writeline("});")
+        self.writeline("};")
+        self.writeline("")
+
+        # Pass lambda pointer through void*
+        self.writeline(
+            f"std::function<void(AOTIMetalKernelFunctionHandle)> {wrapper_name} = {lambda_name};"
+        )
+        self.writeline(
+            f"aoti_torch_mps_run_command_block(get_{name}_handle(), aoti_torch_mps_shared_callback, &{wrapper_name});"
+        )
 
     @staticmethod
     def get_device_include_path(device: str) -> str:
```
````diff
@@ -132,49 +225,77 @@ class CppWrapperMps(CppWrapperGpu):
     def codegen_additional_funcs(self) -> None:
         """
-        We want to codegen the mps kernel function variable initializations
-        ahead of time. This is so that if we reuse kernels within subgraphs, we
-        don't need to worry about the scope in which we're initializing the
-        variables. Instead we will just initialize the variables all at the top
-        level.
+        Generate thread-safe lazy singleton pattern for MPS shader libraries with RAII cleanup.
 
-        The kernel function variable initializations should look something like:
+        The generated code will look like:
         ```
-        const std::shared_ptr<at::native::mps::MetalKernelFunction> get_mps_lib_0() {
-            static const auto func = mps_lib_0.getKernelFunction("generated_kernel");
-            return func;
-        }
         AOTIMetalKernelFunctionHandle get_mps_lib_0_handle() {
-            static const auto handle = AOTIMetalKernelFunctionHandle(get_mps_lib_0().get());
-            return handle;
+            static auto kernel_handle = []() {
+                AOTIMetalShaderLibraryHandle lib_handle = nullptr;
+                AOTIMetalKernelFunctionHandle kern_handle = nullptr;
+
+                aoti_torch_mps_create_shader_library(mps_lib_0_source, &lib_handle);
+                aoti_torch_mps_get_kernel_function(lib_handle, "generated_kernel", &kern_handle);
+
+                // RAII wrapper with custom deleter
+                auto lib_deleter = [](AOTIMetalShaderLibraryHandle h) {
+                    if (h) aoti_torch_mps_delete_shader_library(h);
+                };
+
+                using LibDeleter = decltype(lib_deleter);
+                using LibPtr = std::unique_ptr<AOTIMetalShaderLibraryOpaque, LibDeleter>;
+
+                // Return pair of kernel handle and library smart pointer for cleanup
+                return std::make_pair(kern_handle, LibPtr(lib_handle, lib_deleter));
+            }();
+            return kernel_handle.first;
         }
         ```
         """
+
+        # Add shimified handles and functions
+        shader_libraries: OrderedSet[str] = OrderedSet()
         for line in self.lines:
             if not isinstance(line, KernelCallLine):
                 continue
             if line.device.type != "mps":
                 continue
 
-            # Only add handle definition once
+            # Extract library name from kernel name (e.g., "mps_lib_0" from kernel calls)
             if line.kernel_name not in self._used_kernel_names:
                 self._used_kernel_names.add(line.kernel_name)
+                shader_libraries.add(line.kernel_name)
 
-                self.prefix.writeline(
-                    f"const std::shared_ptr<at::native::mps::MetalKernelFunction> get_{line.kernel_name}() {{"
-                )
-                self.prefix.writeline(
-                    f' static const auto func = {line.kernel_name}.getKernelFunction("generated_kernel");'
-                )
-                self.prefix.writeline(" return func;")
-                self.prefix.writeline("}")
-
-                self.prefix.writeline(
-                    f"AOTIMetalKernelFunctionHandle get_{line.kernel_name}_handle() {{"
-                )
-                self.prefix.writeline(
-                    f" static const auto handle = AOTIMetalKernelFunctionHandle(get_{line.kernel_name}().get());"
-                )
-                self.prefix.writeline(" return handle;")
-                self.prefix.writeline("}")
+        # NOTE: For shimified version, we expect the shader source constant to be generated
+        # by the existing MPS shader generation process, but instead of instantiating the
+        # DynamicMetalShaderLibrary directly, we'll use our shim functions.
+        # The existing codegen should produce something like:
+        #     const char* mps_lib_0_source = R"MTL(...shader_source...)MTL";
+        # instead of:
+        #     at::native::mps::DynamicMetalShaderLibrary mps_lib_0(R"MTL(...shader_source...)MTL");
+
+        # Generate thread-safe lazy singleton with RAII for each library
+        for lib_name in shader_libraries:
+            self.prefix.splice(f"""
+                AOTIMetalKernelFunctionHandle get_{lib_name}_handle() {{
+                    static auto kernel_handle = []() {{
+                        AOTIMetalShaderLibraryHandle lib_handle = nullptr;
+                        AOTIMetalKernelFunctionHandle kern_handle = nullptr;
+
+                        aoti_torch_mps_create_shader_library({lib_name}_source, &lib_handle);
+                        aoti_torch_mps_get_kernel_function(lib_handle, "generated_kernel", &kern_handle);
+
+                        // RAII wrapper with custom deleter
+                        auto lib_deleter = [](AOTIMetalShaderLibraryHandle h) {{
+                            if (h) aoti_torch_mps_delete_shader_library(h);
+                        }};
+
+                        using LibDeleter = decltype(lib_deleter);
+                        using LibPtr = std::unique_ptr<AOTIMetalShaderLibraryOpaque, LibDeleter>;
+
+                        // Return pair of kernel handle and library smart pointer for cleanup
+                        return std::make_pair(kern_handle, LibPtr(lib_handle, lib_deleter));
+                    }}();
+                    return kernel_handle.first;
+                }}
+            """)
````
```diff
@@ -1058,10 +1058,8 @@ class MetalScheduling(SIMDScheduling):
         wrapper.src_to_kernel[src_code] = kernel_name
 
         if V.graph.cpp_wrapper:
-            src_code = (
-                f"at::native::mps::DynamicMetalShaderLibrary {mps_lib_name}"
-                + src_code
-            )
+            # For shimified version, generate source constant instead of direct instantiation
+            src_code = f"const char* {mps_lib_name}_source = " + src_code
 
         origins, detailed_origins = get_kernel_metadata(node_schedule, wrapper)
         metadata_comment = f"{origins}\n{detailed_origins}"
```
```diff
@@ -3,12 +3,32 @@
 
 #include <torch/csrc/inductor/aoti_torch/c/shim.h>
 
+struct AOTIMetalKernelFunctionOpaque;
+using AOTIMetalKernelFunctionHandle = AOTIMetalKernelFunctionOpaque*;
+
+struct AOTIMetalShaderLibraryOpaque;
+using AOTIMetalShaderLibraryHandle = AOTIMetalShaderLibraryOpaque*;
+
 #ifdef __cplusplus
 extern "C" {
 #endif
 
-struct AOTIMetalKernelFunctionOpaque;
-using AOTIMetalKernelFunctionHandle = AOTIMetalKernelFunctionOpaque*;
+// MetalShaderLibrary functions
+AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_create_shader_library(
+    const char* metal_shader_source,
+    AOTIMetalShaderLibraryHandle* library_handle);
+
+AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_delete_shader_library(
+    AOTIMetalShaderLibraryHandle library_handle);
+
+AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_get_kernel_function(
+    AOTIMetalShaderLibraryHandle library_handle,
+    const char* kernel_name,
+    AOTIMetalKernelFunctionHandle* function_handle);
+
+// MetalKernelFunction functions
+AOTI_TORCH_EXPORT AOTITorchError
+aoti_torch_mps_start_encoding(AOTIMetalKernelFunctionHandle func);
 
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_set_arg_tensor(
     AOTIMetalKernelFunctionHandle func,
```
```diff
@@ -20,6 +40,27 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_set_arg_int(
     unsigned idx,
     int64_t val);
 
+AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_dispatch_single(
+    AOTIMetalKernelFunctionHandle func,
+    uint64_t length);
+
+AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_dispatch_single_with_group_size(
+    AOTIMetalKernelFunctionHandle func,
+    uint64_t length,
+    uint64_t group_size);
+
+AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_dispatch_array(
+    AOTIMetalKernelFunctionHandle func,
+    const uint64_t* length,
+    size_t length_size);
+
+AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_dispatch_array_with_group_size(
+    AOTIMetalKernelFunctionHandle func,
+    const uint64_t* length,
+    size_t length_size,
+    const uint64_t* group_size,
+    size_t group_size_size);
+
 AOTI_TORCH_EXPORT AOTITorchError
 aoti_torch_mps_malloc(void** buffer, size_t num_bytes);
```
```diff
@@ -39,6 +80,22 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_copy_buffer(
     size_t src_offset,
     size_t dst_offset);
 
+// C callback function type for command block execution
+typedef void (*aoti_torch_mps_command_block_callback_t)(
+    AOTIMetalKernelFunctionHandle func,
+    void* user_data);
+
+// Shared callback function for std::function trampoline
+AOTI_TORCH_EXPORT void aoti_torch_mps_shared_callback(
+    AOTIMetalKernelFunctionHandle func,
+    void* user_data);
+
+// Pure C version using function pointer and user data for trampoline pattern
+AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_run_command_block(
+    AOTIMetalKernelFunctionHandle func,
+    aoti_torch_mps_command_block_callback_t callback,
+    void* user_data);
+
 #ifdef __cplusplus
 } // extern "C"
 #endif
```
```diff
@@ -27,3 +27,116 @@ AOTITorchError aoti_torch_mps_set_arg_int(
     func->setArg(idx, val);
   });
 }
+
+AOTITorchError aoti_torch_mps_create_shader_library(
+    const char* metal_shader_source,
+    AOTIMetalShaderLibraryHandle* library_handle) {
+  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
+    auto* library = new at::native::mps::DynamicMetalShaderLibrary(
+        std::string(metal_shader_source));
+    *library_handle = reinterpret_cast<AOTIMetalShaderLibraryHandle>(library);
+  });
+}
+
+AOTITorchError aoti_torch_mps_delete_shader_library(
+    AOTIMetalShaderLibraryHandle library_handle) {
+  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
+    auto* library =
+        reinterpret_cast<at::native::mps::MetalShaderLibrary*>(library_handle);
+    delete library;
+  });
+}
+
+AOTITorchError aoti_torch_mps_get_kernel_function(
+    AOTIMetalShaderLibraryHandle library_handle,
+    const char* kernel_name,
+    AOTIMetalKernelFunctionHandle* function_handle) {
+  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
+    auto* library =
+        reinterpret_cast<at::native::mps::MetalShaderLibrary*>(library_handle);
+    auto* function =
+        library->getCachedKernelFunctionPtr(std::string(kernel_name));
+    *function_handle =
+        reinterpret_cast<AOTIMetalKernelFunctionHandle>(function);
+  });
+}
+
+AOTITorchError aoti_torch_mps_start_encoding(
+    AOTIMetalKernelFunctionHandle func) {
+  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
+    auto* function_ptr =
+        reinterpret_cast<at::native::mps::MetalKernelFunction*>(func);
+    function_ptr->startEncoding();
+  });
+}
+
+AOTITorchError aoti_torch_mps_dispatch_single(
+    AOTIMetalKernelFunctionHandle func,
+    uint64_t length) {
+  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
+    auto* function_ptr =
+        reinterpret_cast<at::native::mps::MetalKernelFunction*>(func);
+    function_ptr->dispatch(length);
+  });
+}
+
+AOTITorchError aoti_torch_mps_dispatch_single_with_group_size(
+    AOTIMetalKernelFunctionHandle func,
+    uint64_t length,
+    uint64_t group_size) {
+  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
+    auto* function_ptr =
+        reinterpret_cast<at::native::mps::MetalKernelFunction*>(func);
+    function_ptr->dispatch(length, group_size);
+  });
+}
+
+AOTITorchError aoti_torch_mps_dispatch_array(
+    AOTIMetalKernelFunctionHandle func,
+    const uint64_t* length,
+    size_t length_size) {
+  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
+    auto* function_ptr =
+        reinterpret_cast<at::native::mps::MetalKernelFunction*>(func);
+    c10::ArrayRef<uint64_t> length_ref(length, length_size);
+    function_ptr->dispatch(length_ref);
+  });
+}
+
+AOTITorchError aoti_torch_mps_dispatch_array_with_group_size(
+    AOTIMetalKernelFunctionHandle func,
+    const uint64_t* length,
+    size_t length_size,
+    const uint64_t* group_size,
+    size_t group_size_size) {
+  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
+    auto* function_ptr =
+        reinterpret_cast<at::native::mps::MetalKernelFunction*>(func);
+    c10::ArrayRef<uint64_t> length_ref(length, length_size);
+    c10::ArrayRef<uint64_t> group_size_ref(group_size, group_size_size);
+    function_ptr->dispatch(length_ref, group_size_ref);
+  });
+}
+
+// Shared callback function for std::function trampoline
+void aoti_torch_mps_shared_callback(
+    AOTIMetalKernelFunctionHandle func,
+    void* user_data) {
+  auto* function_wrapper =
+      static_cast<std::function<void(AOTIMetalKernelFunctionHandle)>*>(
+          user_data);
+  (*function_wrapper)(func);
+}
+
+// Pure C version using function pointer and user data for trampoline pattern
+AOTITorchError aoti_torch_mps_run_command_block(
+    AOTIMetalKernelFunctionHandle func,
+    aoti_torch_mps_command_block_callback_t callback,
+    void* user_data) {
+  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
+    auto* function_ptr =
+        reinterpret_cast<at::native::mps::MetalKernelFunction*>(func);
+    function_ptr->runCommandBlock(
+        [callback, func, user_data]() { callback(func, user_data); });
+  });
+}
```
```diff
@@ -1,4 +1,3 @@
-#include <ATen/native/mps/MetalShaderLibrary.h>
 #include <torch/csrc/inductor/aoti_torch/c/shim_mps.h>
 #include <torch/csrc/inductor/aoti_torch/utils.h>
 #include <ATen/mps/MPSAllocatorInterface.h>
```
```diff
@@ -6,7 +5,6 @@
 #include <ATen/mps/MPSStream.h>
 #include <ATen/mps/MPSProfiler.h>
 
-
 using namespace torch::aot_inductor;
 
 AOTITorchError aoti_torch_mps_malloc(
@@ -33,7 +31,6 @@ AOTITorchError aoti_torch_mps_free(
   });
 }
 
-
 AOTITorchError
 aoti_torch_mps_memcpy(void* buffer, size_t constant_offset, size_t bytes_read, size_t data_size, uint8_t* constants_start) {
   AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
@@ -46,7 +43,6 @@ aoti_torch_mps_memcpy(void* buffer, size_t constant_offset, size_t bytes_read, s
 AOTITorchError
 aoti_torch_mps_copy_buffer(void* src_buffer, void* dst_buffer, size_t data_size, size_t src_offset, size_t dst_offset) {
   AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
 
     auto src_mtl_buffer = (id<MTLBuffer>)src_buffer;
     auto dst_mtl_buffer = (id<MTLBuffer>)dst_buffer;
-
```