mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
This PR allows users to author a CUDA kernel in Python.
```
from torch.cuda.jiterator import create_jit_fn
code_string = "template <typename T> T my_kernel(T x, T y, T alpha) { return -x * y + x - y + alpha; }"
jitted_fn = create_jit_fn(code_string, alpha=0)
a = torch.rand(3, device='cuda')
b = torch.rand(3, device='cuda')
result = jitted_fn(a, b, alpha=1.0)
```
Limitations:
- Only supports elementwise kernel
- 1~8 tensor inputs (empty input, e.g. factory methods, is not supported)
- input tensors must live on a CUDA device
- cpu Scalar is not supported
- kwargs must be pre-declared when calling create_jit_fn
- kwargs must be convertible to at::Scalar, one of float64, int64_t, bool. (complex is not supported for now)
TODOs:
- [x] consolidate union and c10::variant implementation
- [x] plug into existing op testing framework
- [ ] rename files, place files in the right folder
- [ ] place util functions in the right file
- [x] enforce assumptions in python interface e.g <8 inputs, kwargs types
- [x] Add user-facing documentation
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76394
Approved by: https://github.com/mruberry
105 lines
3.9 KiB
Python
105 lines
3.9 KiB
Python
import torch
|
|
from torch import Tensor
|
|
from typing import Callable, List
|
|
|
|
import re
|
|
|
|
# The public-name list is empty, so ``from <module> import *`` exports nothing from this file.
__all__ : List[str] = []
|
|
|
|
class _CodeParser:
    """Split a templated C++/CUDA function definition into its syntactic parts.

    The code string is expected to look like::

        template <typename T> T my_kernel(T x, T y) { return x + y; }

    After construction the following attributes hold the matched text:

    * ``template_params`` -- the angle-bracketed template parameter list,
      brackets included.
    * ``return_type`` -- a single identifier-like word (``\\w+``).
    * ``function_name`` -- a single identifier-like word (``\\w+``).
    * ``function_params`` -- the parenthesised parameter list, parentheses
      included.
    * ``function_body`` -- the braced body, braces included.

    Raises:
        ValueError: if ``code_string`` does not follow the pattern above.
    """

    # Compiled once for the whole class instead of being rebuilt on every
    # instantiation.  DOTALL lets ``.`` span newlines so multi-line kernels
    # match.  The quantifiers are intentionally greedy, exactly as before:
    # each group extends to the *last* closing delimiter that still allows the
    # rest of the pattern to match.
    _PATTERN = re.compile(
        r"\s*"
        r"template"
        r"\s*(?P<template_params>\<.+\>)"
        r"\s*(?P<return_type>\w+)"
        r"\s+(?P<function_name>\w+)"
        r"\s*(?P<function_params>\(.+\))"
        r"\s*(?P<function_body>\{.+\})"
        r"\s*",
        re.DOTALL,
    )

    def __init__(self, code_string: str):
        # ``match`` anchors only at the start of the string, so any text
        # following the function body is ignored rather than rejected.
        result = self._PATTERN.match(code_string)

        if result is None:
            # ValueError is a subclass of Exception, so callers that catch
            # ``Exception`` keep working.
            raise ValueError(f"Couldn't parse code, please check correctness:\n {code_string}")

        self.template_params = result["template_params"]
        self.return_type = result["return_type"]
        self.function_name = result["function_name"]
        self.function_params = result["function_params"]
        self.function_body = result["function_body"]
|
|
|
|
|
|
def _create_jit_fn(code_string: str, **kwargs) -> Callable:
    """
    Create a jiterator-generated cuda kernel for an elementwise op.

    The code string has to be a valid CUDA function that describes the computation for a single element. The code
    string has to follow the c++ template pattern, as shown in the example below. This function will be inlined
    into elementwise kernel template, and compiled on the fly. Compiled kernel will be cached in memory, as well as
    local temp dir.

    Jiterator-generated kernels accept noncontiguous tensors, and support broadcasting and type promotion.

    Args:
        code_string (string): CUDA code string to be compiled by jiterator.
        kwargs (Dict, optional): Keyword arguments for generated function. The values given here act as defaults;
            only the keys declared here may be overridden when the generated function is called.

    Returns:
        A callable taking the input tensors positionally and, optionally, overrides for the declared keyword
        arguments. Calling it raises ``AssertionError`` when CUDA is unavailable or more than 8 tensors are
        passed, and ``KeyError`` when a keyword argument was not declared at creation time.

    Raises:
        Exception: if ``code_string`` cannot be parsed into a templated function definition.

    Examples:
        >>> code_string = "template <typename T> T my_kernel(T x, T y, T alpha) { return -x + alpha * y; }"
        >>> jitted_fn = create_jit_fn(code_string, alpha=1.0)
        >>> a = torch.rand(3, device='cuda')
        >>> b = torch.rand(3, device='cuda')
        >>> # invoke jitted function like a regular python function
        >>> result = jitted_fn(a, b, alpha=3.14)

    .. warning::
        This API is in beta and may change in future releases.

    .. warning::
        Jiterator only supports up to 8 tensor inputs

    .. warning::
        All input tensors must live in CUDA device

    """
    class JittedFunction:
        """Callable wrapper that compiles and launches the kernel described by ``code_string``."""

        def __init__(self, code_string: str, **kwargs):
            self.code_string = code_string

            # Only the function name is needed from the parsed code; it is
            # passed to the backend as the kernel name.
            parsed_code = _CodeParser(code_string)
            self.kernel_name = parsed_code.function_name

            # Declared keyword arguments together with their default values.
            self.kwargs_dict = kwargs
            self.is_cuda_available = torch.cuda.is_available()

        def __call__(self, *tensors: Tensor, **kwargs):
            # Jiterator follows torch.cuda's lazy initialization behavior:
            # availability was sampled when the wrapper was created, but the
            # failure is reported only here, at invocation time.
            #
            # The checks raise AssertionError explicitly instead of using
            # ``assert`` so that they are not stripped under ``python -O``.
            if not self.is_cuda_available:
                raise AssertionError("Jiterator is only supported on CUDA GPUs, no CUDA GPUs are available.")

            if len(tensors) > 8:
                raise AssertionError("jiterator only supports up to 8 tensor inputs.")

            # Start from the declared defaults and overlay the call-site
            # overrides, rejecting any keyword that was never declared.
            expanded_kwargs = self.kwargs_dict.copy()
            for key, value in kwargs.items():
                if key not in self.kwargs_dict:
                    raise KeyError(f"{key} is not declared in function definition")
                expanded_kwargs[key] = value

            return torch._C._cuda_jiterator_compile_and_launch_kernel(
                self.code_string,
                self.kernel_name,
                tensors,
                expanded_kwargs)

    return JittedFunction(code_string, **kwargs)
|