Inline JIT C++ Extensions (#7059)
Adds the ability to JIT-compile C++ extensions from source strings
>>> from torch.utils.cpp_extension import load_inline
>>> source = '''
at::Tensor sin_add(at::Tensor x, at::Tensor y) {
return x.sin() + y.sin();
}
'''
>>> module = load_inline(name='inline_extension', cpp_sources=source, functions='sin_add')
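The returned module then behaves like an ordinary Python extension module; a minimal usage sketch (the tensor shapes and values here are illustrative, not part of the commit):
>>> import torch
>>> x = torch.randn(4, 4)
>>> y = torch.randn(4, 4)
>>> z = module.sin_add(x, y)  # same result as x.sin() + y.sin()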
Fixes #7012
* Inline JIT C++ Extensions
* jit_compile_sources -> jit_compile
* Split up test into CUDA and non-CUDA parts
* Documentation fixes
* Implement prologue and epilogue generation
* Remove extra newline
* Only create the CUDA source file when cuda_sources is passed
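For the CUDA path, usage looks roughly like the following condensed sketch of the new CUDA test added below (the kernel body and names are illustrative; the C++ source need only declare the function signature):
>>> cuda_source = '''
__global__ void cos_add_kernel(const float* x, const float* y, float* out, int size) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < size) out[i] = __cosf(x[i]) + __cosf(y[i]);
}
at::Tensor cos_add(at::Tensor x, at::Tensor y) {
  auto out = at::zeros_like(x);
  const int threads = 1024;
  cos_add_kernel<<<(x.numel() + threads - 1) / threads, threads>>>(
      x.data<float>(), y.data<float>(), out.data<float>(), x.numel());
  return out;
}
'''
>>> module = load_inline(name='inline_extension_cuda',
                         cpp_sources='at::Tensor cos_add(at::Tensor x, at::Tensor y);',
                         cuda_sources=cuda_source,
                         functions=['cos_add'])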
parent c5978db094
commit b70b7a80d4
@@ -6,6 +6,7 @@ torch.utils.cpp_extension
 .. autofunction:: CUDAExtension
 .. autofunction:: BuildExtension
 .. autofunction:: load
+.. autofunction:: load_inline
 .. autofunction:: include_paths
 .. autofunction:: check_compiler_abi_compatibility
 .. autofunction:: verify_ninja_availability
@@ -72,8 +72,8 @@ class TestCppExtension(common.TestCase):
     def test_cuda_extension(self):
         import torch_test_cpp_extension.cuda as cuda_extension

-        x = torch.FloatTensor(100).zero_().cuda()
-        y = torch.FloatTensor(100).zero_().cuda()
+        x = torch.zeros(100, device='cuda', dtype=torch.float32)
+        y = torch.zeros(100, device='cuda', dtype=torch.float32)

         z = cuda_extension.sigmoid_add(x, y).cpu()
@@ -92,8 +92,8 @@ class TestCppExtension(common.TestCase):
             extra_cuda_cflags=['-O2'],
             verbose=True)

-        x = torch.FloatTensor(100).zero_().cuda()
-        y = torch.FloatTensor(100).zero_().cuda()
+        x = torch.zeros(100, device='cuda', dtype=torch.float32)
+        y = torch.zeros(100, device='cuda', dtype=torch.float32)

         z = module.sigmoid_add(x, y).cpu()
@@ -106,6 +106,111 @@ class TestCppExtension(common.TestCase):
         has_value = cpp_extension.function_taking_optional(None)
         self.assertFalse(has_value)

+    def test_inline_jit_compile_extension_with_functions_as_list(self):
+        cpp_source = '''
+        at::Tensor tanh_add(at::Tensor x, at::Tensor y) {
+          return x.tanh() + y.tanh();
+        }
+        '''
+
+        module = torch.utils.cpp_extension.load_inline(
+            name='inline_jit_extension_with_functions_list',
+            cpp_sources=cpp_source,
+            functions='tanh_add',
+            verbose=True)
+
+        self.assertEqual(module.tanh_add.__doc__.split('\n')[2], 'tanh_add')
+
+        x = torch.randn(4, 4)
+        y = torch.randn(4, 4)
+
+        z = module.tanh_add(x, y)
+        self.assertEqual(z, x.tanh() + y.tanh())
+
+    def test_inline_jit_compile_extension_with_functions_as_dict(self):
+        cpp_source = '''
+        at::Tensor tanh_add(at::Tensor x, at::Tensor y) {
+          return x.tanh() + y.tanh();
+        }
+        '''
+
+        module = torch.utils.cpp_extension.load_inline(
+            name='inline_jit_extension_with_functions_dict',
+            cpp_sources=cpp_source,
+            functions={'tanh_add': 'Tanh and then sum :D'},
+            verbose=True)
+
+        self.assertEqual(
+            module.tanh_add.__doc__.split('\n')[2], 'Tanh and then sum :D')
+
+    def test_inline_jit_compile_extension_multiple_sources_and_no_functions(self):
+        cpp_source1 = '''
+        at::Tensor sin_add(at::Tensor x, at::Tensor y) {
+          return x.sin() + y.sin();
+        }
+        '''
+
+        cpp_source2 = '''
+        #include <torch/torch.h>
+        at::Tensor sin_add(at::Tensor x, at::Tensor y);
+        PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+          m.def("sin_add", &sin_add, "sin(x) + sin(y)");
+        }
+        '''
+
+        module = torch.utils.cpp_extension.load_inline(
+            name='inline_jit_extension',
+            cpp_sources=[cpp_source1, cpp_source2],
+            verbose=True)
+
+        x = torch.randn(4, 4)
+        y = torch.randn(4, 4)
+
+        z = module.sin_add(x, y)
+        self.assertEqual(z, x.sin() + y.sin())
+
+    @unittest.skipIf(not TEST_CUDA, "CUDA not found")
+    def test_inline_jit_compile_extension_cuda(self):
+        cuda_source = '''
+        __global__ void cos_add_kernel(
+            const float* __restrict__ x,
+            const float* __restrict__ y,
+            float* __restrict__ output,
+            const int size) {
+          const auto index = blockIdx.x * blockDim.x + threadIdx.x;
+          if (index < size) {
+            output[index] = __cosf(x[index]) + __cosf(y[index]);
+          }
+        }
+
+        at::Tensor cos_add(at::Tensor x, at::Tensor y) {
+          auto output = at::zeros_like(x);
+          const int threads = 1024;
+          const int blocks = (output.numel() + threads - 1) / threads;
+          cos_add_kernel<<<blocks, threads>>>(
+              x.data<float>(), y.data<float>(), output.data<float>(), output.numel());
+          return output;
+        }
+        '''
+
+        # Here, the C++ source need only declare the function signature.
+        cpp_source = 'at::Tensor cos_add(at::Tensor x, at::Tensor y);'
+
+        module = torch.utils.cpp_extension.load_inline(
+            name='inline_jit_extension_cuda',
+            cpp_sources=cpp_source,
+            cuda_sources=cuda_source,
+            functions=['cos_add'],
+            verbose=True)
+
+        self.assertEqual(module.cos_add.__doc__.split('\n')[2], 'cos_add')
+
+        x = torch.randn(4, 4, device='cuda', dtype=torch.float32)
+        y = torch.randn(4, 4, device='cuda', dtype=torch.float32)
+
+        z = module.cos_add(x, y)
+        self.assertEqual(z, x.cos() + y.cos())
+
+    def test_inline_jit_compile_extension_throws_when_functions_is_bad(self):
+        with self.assertRaises(ValueError):
+            torch.utils.cpp_extension.load_inline(
+                name='invalid_jit_extension', cpp_sources='', functions=5)
+

 if __name__ == '__main__':
     common.run_tests()
@@ -467,20 +467,119 @@ def load(name,
             extra_cflags=['-O2'],
             verbose=True)
     '''
-    verify_ninja_availability()
-
-    # Allows sources to be a single path or a list of paths.
-    if isinstance(sources, str):
-        sources = [sources]
-
-    if build_directory is None:
-        build_directory = _get_build_directory(name, verbose)
+    return _jit_compile(
+        name,
+        [sources] if isinstance(sources, str) else sources,
+        extra_cflags,
+        extra_cuda_cflags,
+        extra_ldflags,
+        extra_include_paths,
+        build_directory or _get_build_directory(name, verbose),
+        verbose)
+
+
+def load_inline(name,
+                cpp_sources,
+                cuda_sources=None,
+                functions=None,
+                extra_cflags=None,
+                extra_cuda_cflags=None,
+                extra_ldflags=None,
+                extra_include_paths=None,
+                build_directory=None,
+                verbose=False):
+    '''
+    Loads a PyTorch C++ extension just-in-time (JIT) from string sources.
+
+    This function behaves exactly like :func:`load`, but takes its sources as
+    strings rather than filenames. These strings are stored to files in the
+    build directory, after which the behavior of :func:`load_inline` is
+    identical to :func:`load`. Strings passed in ``cpp_sources`` (a string or
+    list of strings) are stored with a ``.cpp`` extension, and the string or
+    list of strings passed in ``cuda_sources`` are stored with a ``.cu``
+    extension.
+
+    Example:
+        >>> from torch.utils.cpp_extension import load_inline
+        >>> source = \'\'\'
+        at::Tensor sin_add(at::Tensor x, at::Tensor y) {
+          return x.sin() + y.sin();
+        }
+        \'\'\'
+        >>> module = load_inline(name='inline_extension',
+                                 cpp_sources=[source],
+                                 functions=['sin_add'])
+    '''
+    build_directory = build_directory or _get_build_directory(name, verbose)
+
+    source_files = []
+
+    if isinstance(cpp_sources, str):
+        cpp_sources = [cpp_sources]
+    cuda_sources = cuda_sources or []
+    if isinstance(cuda_sources, str):
+        cuda_sources = [cuda_sources]
+
+    cpp_sources.insert(0, '#include <torch/torch.h>')
+
+    # If `functions` is supplied, we create the pybind11 bindings for the user.
+    # Here, `functions` is (or becomes, after some processing) a map from
+    # function names to function docstrings.
+    if functions is not None:
+        cpp_sources.append('PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {')
+        if isinstance(functions, str):
+            functions = [functions]
+        if isinstance(functions, list):
+            # Make the function docstring the same as the function name.
+            functions = dict((f, f) for f in functions)
+        elif not isinstance(functions, dict):
+            raise ValueError(
+                "Expected 'functions' to be a list or dict, but was {}".format(
+                    type(functions)))
+        for function_name, docstring in functions.items():
+            cpp_sources.append('m.def("{0}", &{0}, "{1}");'.format(
+                function_name, docstring))
+        cpp_sources.append('}')
+
+    cpp_source_path = os.path.join(build_directory, 'main.cpp')
+    with open(cpp_source_path, 'w') as cpp_source_file:
+        cpp_source_file.write('\n'.join(cpp_sources))
+
+    sources = [cpp_source_path]
+
+    if cuda_sources:
+        cuda_sources.insert(0, '#include <ATen/ATen.h>')
+        cuda_sources.insert(1, '#include <cuda.h>')
+        cuda_sources.insert(2, '#include <cuda_runtime.h>')
+
+        cuda_source_path = os.path.join(build_directory, 'cuda.cu')
+        with open(cuda_source_path, 'w') as cuda_source_file:
+            cuda_source_file.write('\n'.join(cuda_sources))
+
+        sources.append(cuda_source_path)
+
+    return _jit_compile(
+        name,
+        sources,
+        extra_cflags,
+        extra_cuda_cflags,
+        extra_ldflags,
+        extra_include_paths,
+        build_directory,
+        verbose)
+
+
+def _jit_compile(name,
+                 sources,
+                 extra_cflags,
+                 extra_cuda_cflags,
+                 extra_ldflags,
+                 extra_include_paths,
+                 build_directory,
+                 verbose):
     baton = FileBaton(os.path.join(build_directory, 'lock'))

     if baton.try_acquire():
         try:
+            verify_ninja_availability()
             check_compiler_abi_compatibility(os.environ.get('CXX', 'c++'))
             with_cuda = any(map(_is_cuda_file, sources))
             extra_ldflags = _prepare_ldflags(
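For reference, given the sin_add example from the commit message, the main.cpp that load_inline writes into the build directory would look roughly like this (a sketch assembled from the prologue/epilogue generation above; exact whitespace may differ):

// Prologue inserted by load_inline.
#include <torch/torch.h>

// User-supplied cpp_sources follow verbatim.
at::Tensor sin_add(at::Tensor x, at::Tensor y) {
  return x.sin() + y.sin();
}

// Epilogue generated because `functions` was supplied; the docstring defaults
// to the function name when `functions` is a string or a list.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("sin_add", &sin_add, "sin_add");
}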