Inline JIT C++ Extensions (#7059)

Adds the ability to JIT-compile C++ extensions from source strings:

>>> from torch.utils.cpp_extension import load_inline
>>> source = '''
    at::Tensor sin_add(at::Tensor x, at::Tensor y) {
      return x.sin() + y.sin();
    }
'''
>>> module = load_inline(name='inline_extension', cpp_sources=source, functions='sin_add')
Fixes #7012

* Inline JIT C++ Extensions

* jit_compile_sources -> jit_compile

* Split up test into CUDA and non-CUDA parts

* Documentation fixes

* Implement prologue and epilogue generation

* Remove extra newline

* Only create the CUDA source file when cuda_sources is passed
This commit is contained in:
Peter Goldsborough 2018-04-30 08:48:44 -07:00 committed by Edward Z. Yang
parent c5978db094
commit b70b7a80d4
3 changed files with 218 additions and 13 deletions

View File

@ -6,6 +6,7 @@ torch.utils.cpp_extension
.. autofunction:: CUDAExtension
.. autofunction:: BuildExtension
.. autofunction:: load
.. autofunction:: load_inline
.. autofunction:: include_paths
.. autofunction:: check_compiler_abi_compatibility
.. autofunction:: verify_ninja_availability

View File

@ -72,8 +72,8 @@ class TestCppExtension(common.TestCase):
def test_cuda_extension(self):
import torch_test_cpp_extension.cuda as cuda_extension
x = torch.FloatTensor(100).zero_().cuda()
y = torch.FloatTensor(100).zero_().cuda()
x = torch.zeros(100, device='cuda', dtype=torch.float32)
y = torch.zeros(100, device='cuda', dtype=torch.float32)
z = cuda_extension.sigmoid_add(x, y).cpu()
@ -92,8 +92,8 @@ class TestCppExtension(common.TestCase):
extra_cuda_cflags=['-O2'],
verbose=True)
x = torch.FloatTensor(100).zero_().cuda()
y = torch.FloatTensor(100).zero_().cuda()
x = torch.zeros(100, device='cuda', dtype=torch.float32)
y = torch.zeros(100, device='cuda', dtype=torch.float32)
z = module.sigmoid_add(x, y).cpu()
@ -106,6 +106,111 @@ class TestCppExtension(common.TestCase):
has_value = cpp_extension.function_taking_optional(None)
self.assertFalse(has_value)
def test_inline_jit_compile_extension_with_functions_as_list(self):
    # JIT-compiles one inline C++ function and lets load_inline generate the
    # pybind11 binding for it.
    cpp_source = '''
at::Tensor tanh_add(at::Tensor x, at::Tensor y) {
return x.tanh() + y.tanh();
}
'''

    # `functions` is given as a bare name: load_inline wraps a string into a
    # one-element list and then uses each function name as its own docstring.
    module = torch.utils.cpp_extension.load_inline(
        name='inline_jit_extension_with_functions_list',
        cpp_sources=cpp_source,
        functions='tanh_add',
        verbose=True)

    # The third line of the generated __doc__ is expected to carry the
    # docstring passed to m.def(), which here equals the function name.
    self.assertEqual(module.tanh_add.__doc__.split('\n')[2], 'tanh_add')

    x = torch.randn(4, 4)
    y = torch.randn(4, 4)

    z = module.tanh_add(x, y)
    self.assertEqual(z, x.tanh() + y.tanh())
def test_inline_jit_compile_extension_with_functions_as_dict(self):
    # `functions` is a {name: docstring} mapping here, so the binding should
    # expose the custom docstring instead of the function name.
    source = '''
at::Tensor tanh_add(at::Tensor x, at::Tensor y) {
return x.tanh() + y.tanh();
}
'''
    expected_doc = 'Tanh and then sum :D'

    extension = torch.utils.cpp_extension.load_inline(
        name='inline_jit_extension_with_functions_dict',
        cpp_sources=source,
        functions={'tanh_add': expected_doc},
        verbose=True)

    # The user-supplied docstring is expected on the third line of __doc__.
    doc_lines = extension.tanh_add.__doc__.split('\n')
    self.assertEqual(doc_lines[2], expected_doc)
def test_inline_jit_compile_extension_multiple_sources_and_no_functions(self):
    # Without `functions`, load_inline generates no binding, so the second
    # source string has to define the pybind11 module by hand.
    definition = '''
at::Tensor sin_add(at::Tensor x, at::Tensor y) {
return x.sin() + y.sin();
}
'''
    binding = '''
#include <torch/torch.h>
at::Tensor sin_add(at::Tensor x, at::Tensor y);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("sin_add", &sin_add, "sin(x) + sin(y)");
}
'''

    extension = torch.utils.cpp_extension.load_inline(
        name='inline_jit_extension',
        cpp_sources=[definition, binding],
        verbose=True)

    lhs = torch.randn(4, 4)
    rhs = torch.randn(4, 4)

    actual = extension.sin_add(lhs, rhs)
    self.assertEqual(actual, lhs.sin() + rhs.sin())
@unittest.skipIf(not TEST_CUDA, "CUDA not found")
def test_inline_jit_compile_extension_cuda(self):
    # The kernel and its host-side launcher both live in the CUDA source.
    kernel_source = '''
__global__ void cos_add_kernel(
const float* __restrict__ x,
const float* __restrict__ y,
float* __restrict__ output,
const int size) {
const auto index = blockIdx.x * blockDim.x + threadIdx.x;
if (index < size) {
output[index] = __cosf(x[index]) + __cosf(y[index]);
}
}
at::Tensor cos_add(at::Tensor x, at::Tensor y) {
auto output = at::zeros_like(x);
const int threads = 1024;
const int blocks = (output.numel() + threads - 1) / threads;
cos_add_kernel<<<blocks, threads>>>(x.data<float>(), y.data<float>(), output.data<float>(), output.numel());
return output;
}
'''

    # The C++ side only forward-declares the launcher, which is enough for
    # the generated binding to take its address.
    declaration = 'at::Tensor cos_add(at::Tensor x, at::Tensor y);'

    extension = torch.utils.cpp_extension.load_inline(
        name='inline_jit_extension_cuda',
        cpp_sources=declaration,
        cuda_sources=kernel_source,
        functions=['cos_add'],
        verbose=True)

    # With a list of names, the docstring defaults to the function name.
    doc_lines = extension.cos_add.__doc__.split('\n')
    self.assertEqual(doc_lines[2], 'cos_add')

    tensor_options = dict(device='cuda', dtype=torch.float32)
    lhs = torch.randn(4, 4, **tensor_options)
    rhs = torch.randn(4, 4, **tensor_options)

    actual = extension.cos_add(lhs, rhs)
    self.assertEqual(actual, lhs.cos() + rhs.cos())
def test_inline_jit_compile_extension_throws_when_functions_is_bad(self):
    # `functions` must be a string, list or dict; an int has to be rejected.
    bad_functions = 5
    self.assertRaises(
        ValueError,
        torch.utils.cpp_extension.load_inline,
        name='invalid_jit_extension',
        cpp_sources='',
        functions=bad_functions)

View File

@ -467,20 +467,119 @@ def load(name,
extra_cflags=['-O2'],
verbose=True)
'''
return _jit_compile(
name,
[sources] if isinstance(sources, str) else sources,
extra_cflags,
extra_cuda_cflags,
extra_ldflags,
extra_include_paths,
build_directory or _get_build_directory(name, verbose),
verbose)
verify_ninja_availability()
# Allows sources to be a single path or a list of paths.
if isinstance(sources, str):
sources = [sources]
def load_inline(name,
                cpp_sources,
                cuda_sources=None,
                functions=None,
                extra_cflags=None,
                extra_cuda_cflags=None,
                extra_ldflags=None,
                extra_include_paths=None,
                build_directory=None,
                verbose=False):
    '''
    Loads a PyTorch C++ extension just-in-time (JIT) from string sources.

    This function behaves exactly like :func:`load`, but takes its sources as
    strings rather than filenames. These strings are stored to files in the
    build directory, after which the behavior of :func:`load_inline` is
    identical to :func:`load`.

    The strings in ``cpp_sources`` are joined with newlines and written to
    ``main.cpp``, preceded by ``#include <torch/torch.h>``. If ``cuda_sources``
    is non-empty, its strings are written the same way to ``cuda.cu``, preceded
    by includes of ``ATen/ATen.h``, ``cuda.h`` and ``cuda_runtime.h``. The
    sequences passed by the caller are copied and never modified.

    Args:
        name: The name of the extension to build.
        cpp_sources: A string, or list of strings, containing C++ source code.
        cuda_sources: A string, or list of strings, containing CUDA source
            code. No ``.cu`` file is created when this is empty or ``None``.
        functions: A function name, a list of function names, or a dict
            mapping function names to docstrings. When given, a
            ``PYBIND11_MODULE`` block binding each name is appended to the C++
            source; names given without a docstring use the name itself as
            docstring. When ``None``, the sources must define the module
            themselves.
        extra_cflags: Forwarded unchanged to the JIT build.
        extra_cuda_cflags: Forwarded unchanged to the JIT build.
        extra_ldflags: Forwarded unchanged to the JIT build.
        extra_include_paths: Forwarded unchanged to the JIT build.
        build_directory: Directory the source files are written to. If falsy,
            a directory is obtained from ``_get_build_directory``.
        verbose: Forwarded to the build directory lookup and the JIT build.

    Returns:
        Whatever the JIT build returns for the compiled sources (the loaded
        extension module, as used in the example below).

    Raises:
        ValueError: If ``functions`` is neither a string, a list nor a dict.
            Nothing is written to the build directory in that case.

    Example:
        >>> from torch.utils.cpp_extension import load_inline
        >>> source = \'\'\'
        at::Tensor sin_add(at::Tensor x, at::Tensor y) {
          return x.sin() + y.sin();
        }
        \'\'\'
        >>> module = load_inline(name='inline_extension',
                                 cpp_sources=[source],
                                 functions=['sin_add'])
    '''
    build_directory = build_directory or _get_build_directory(name, verbose)

    # Always work on private lists: inserting the includes and the binding
    # code into the caller's own list would make a second call with the same
    # list emit them twice (and define PYBIND11_MODULE twice).
    if isinstance(cpp_sources, str):
        cpp_sources = [cpp_sources]
    else:
        cpp_sources = list(cpp_sources)

    if not cuda_sources:
        cuda_sources = []
    elif isinstance(cuda_sources, str):
        cuda_sources = [cuda_sources]
    else:
        cuda_sources = list(cuda_sources)

    cpp_lines = ['#include <torch/torch.h>'] + cpp_sources

    # If `functions` is supplied, we create the pybind11 bindings for the user.
    # Here, `functions` is (or becomes, after some processing) a map from
    # function names to function docstrings. It is validated before any file
    # is written so that a bad argument leaves no partial output behind.
    if functions is not None:
        if isinstance(functions, str):
            functions = [functions]
        if isinstance(functions, list):
            # Make the function docstring the same as the function name.
            functions = dict((f, f) for f in functions)
        elif not isinstance(functions, dict):
            raise ValueError(
                "Expected 'functions' to be a list or dict, but was {}".format(
                    type(functions)))
        cpp_lines.append('PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {')
        for function_name, docstring in functions.items():
            cpp_lines.append('m.def("{0}", &{0}, "{1}");'.format(
                function_name, docstring))
        cpp_lines.append('}')

    cpp_source_path = os.path.join(build_directory, 'main.cpp')
    with open(cpp_source_path, 'w') as cpp_source_file:
        cpp_source_file.write('\n'.join(cpp_lines))

    sources = [cpp_source_path]

    # Only create the CUDA source file when CUDA sources were actually passed.
    if cuda_sources:
        cuda_lines = [
            '#include <ATen/ATen.h>',
            '#include <cuda.h>',
            '#include <cuda_runtime.h>',
        ] + cuda_sources

        cuda_source_path = os.path.join(build_directory, 'cuda.cu')
        with open(cuda_source_path, 'w') as cuda_source_file:
            cuda_source_file.write('\n'.join(cuda_lines))

        sources.append(cuda_source_path)

    return _jit_compile(
        name,
        sources,
        extra_cflags,
        extra_cuda_cflags,
        extra_ldflags,
        extra_include_paths,
        build_directory,
        verbose)
def _jit_compile(name,
sources,
extra_cflags,
extra_cuda_cflags,
extra_ldflags,
extra_include_paths,
build_directory,
verbose):
baton = FileBaton(os.path.join(build_directory, 'lock'))
if baton.try_acquire():
try:
verify_ninja_availability()
check_compiler_abi_compatibility(os.environ.get('CXX', 'c++'))
with_cuda = any(map(_is_cuda_file, sources))
extra_ldflags = _prepare_ldflags(