mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
This PR allows user to author a CUDA kernel in python.
```
from torch.cuda.jiterator import create_jit_fn
code_string = "template <typename T> T my_kernel(T x, T y, T alpha) { return -x * y + x - y + alpha; }"
jitted_fn = create_jit_fn(code_string, alpha=0)
a = torch.rand(3, device='cuda')
b = torch.rand(3, device='cuda')
result = jitted_fn(a, b, alpha=1.0)
```
Limitations:
- Only supports elementwise kernel
- 1~8 tensor inputs (empty input, e.g. factory methods, is not supported)
- inputs tensors must live in cuda device
- cpu Scalar is not supported
- kwargs must be pre-declared when calling create_jit_fn
- kwargs must be convertible to at::Scalar, one of float64, int64_t, bool. (complex not support for now)
TODOs:
- [x] consolidate union and c10::variant implementation
- [x] plug into existing op testing framework
- [ ] rename files, place files in the right folder
- [ ] place util functions in the right file
- [x] enforce assumptions in python interface e.g <8 inputs, kwargs types
- [x] Add user-facing documentation
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76394
Approved by: https://github.com/mruberry
|
||
|---|---|---|
| .. | ||
| _static | ||
| _templates | ||
| community | ||
| elastic | ||
| notes | ||
| rpc | ||
| scripts | ||
| amp.rst | ||
| autograd.rst | ||
| backends.rst | ||
| benchmark_utils.rst | ||
| bottleneck.rst | ||
| checkpoint.rst | ||
| complex_numbers.rst | ||
| conf.py | ||
| config_mod.rst | ||
| cpp_extension.rst | ||
| cpp_index.rst | ||
| cuda.rst | ||
| cudnn_persistent_rnn.rst | ||
| cudnn_rnn_determinism.rst | ||
| data.rst | ||
| ddp_comm_hooks.rst | ||
| deploy.rst | ||
| distributed.algorithms.join.rst | ||
| distributed.elastic.rst | ||
| distributed.optim.rst | ||
| distributed.rst | ||
| distributions.rst | ||
| dlpack.rst | ||
| docutils.conf | ||
| fft.rst | ||
| fsdp.rst | ||
| futures.rst | ||
| fx.rst | ||
| hub.rst | ||
| index.rst | ||
| jit_builtin_functions.rst | ||
| jit_language_reference_v2.rst | ||
| jit_language_reference.rst | ||
| jit_python_reference.rst | ||
| jit_unsupported.rst | ||
| jit.rst | ||
| linalg.rst | ||
| math-quantizer-equation.png | ||
| mobile_optimizer.rst | ||
| model_zoo.rst | ||
| monitor.rst | ||
| multiprocessing.rst | ||
| name_inference.rst | ||
| named_tensor.rst | ||
| nested.rst | ||
| nn.functional.rst | ||
| nn.init.rst | ||
| nn.rst | ||
| onnx_supported_aten_ops.rst | ||
| onnx.rst | ||
| optim.rst | ||
| package.rst | ||
| pipeline.rst | ||
| profiler.rst | ||
| quantization-backend-configuration.rst | ||
| quantization-support.rst | ||
| quantization.rst | ||
| random.rst | ||
| rpc.rst | ||
| sparse.rst | ||
| special.rst | ||
| storage.rst | ||
| tensor_attributes.rst | ||
| tensor_view.rst | ||
| tensorboard.rst | ||
| tensors.rst | ||
| testing.rst | ||
| torch.ao.ns._numeric_suite_fx.rst | ||
| torch.ao.ns._numeric_suite.rst | ||
| torch.overrides.rst | ||
| torch.rst | ||
| type_info.rst | ||