Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-07 12:21:27 +01:00
This allows one to do something like the following:
```python
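import torch

# Requires a PyTorch build with MPS support, running on Apple hardware.
x = torch.ones(10, device="mps")

# Compile the Metal source; the kernel `foo` is then callable on the returned object.
m = torch.mps._compile_shader("""
kernel void foo(device float* x, uint idx [[thread_position_in_grid]]) {
  x[idx] += idx;
}
""")

# Launch `foo` on x; each thread adds its grid index to its own element, in place.
m.foo(x)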
```
More generally, this enables writing custom operators with Metal shaders purely in Python.
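To make the kernel's semantics concrete without a Metal device, here is a minimal CPU-only sketch in plain Python. The helper `emulate_foo` is hypothetical and not part of PyTorch; it assumes the launch uses one thread per element, so that `idx` ranges over every position of `x`.

```python
# CPU-only emulation of the `foo` kernel; no Metal or MPS device required.
# `emulate_foo` is a hypothetical helper written for illustration only.
def emulate_foo(x):
    """Apply `x[idx] += idx` once per simulated thread.

    Assumption: one thread is launched per element, so `idx`
    takes every value from 0 to len(x) - 1 exactly once.
    """
    out = list(x)  # copy, so the caller's list is left untouched
    for idx in range(len(out)):  # each iteration stands in for one GPU thread
        out[idx] += idx
    return out


expected = emulate_foo([1.0] * 10)  # mirrors torch.ones(10)
print(expected)
# prints [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]
```

Under that assumption, comparing `x.cpu().tolist()` against `expected` after `m.foo(x)` would be one way to sanity-check a shader on real hardware.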
Pull Request resolved: https://github.com/pytorch/pytorch/pull/141478
Approved by: https://github.com/manuelcandales
11 lines
173 B
C++
#pragma once

#include <torch/csrc/python_headers.h>

namespace torch::mps {

PyMethodDef* python_functions();
void initModule(PyObject* module);

} // namespace torch::mps