pytorch/test/test_cpp_extensions.py
Peter Goldsborough 6f2307ba6a Allow building libraries with setuptools that don't have an ABI suffix (#14130)
Summary:
When using `setuptools` to build a Python extension, setuptools will automatically add an ABI suffix like `cpython-37m-x86_64-linux-gnu` to the shared library name when using Python 3. This is required for extensions meant to be imported as Python modules. When we use setuptools to build shared libraries not meant as Python modules, for example libraries that define and register TorchScript custom ops, having your library called `my_ops.cpython-37m-x86_64-linux-gnu.so` is a bit annoying compared to just `my_ops.so`, especially since you have to reference the library name when loading it with `torch.ops.load_library` in Python.

This PR fixes this by adding a `with_options` class method to the `torch.utils.cpp_extension.BuildExtension` which allows configuring the `BuildExtension`. In this case, the first option we add is `no_python_abi_suffix`, which we then use in `get_ext_filename` (override from `setuptools.build_ext`) to throw away the ABI suffix.

I've added a test `setup.py` in a `no_python_abi_suffix_test` folder.

Fixes https://github.com/pytorch/pytorch/issues/14188

t-vi fmassa soumith
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14130

Differential Revision: D13216575

Pulled By: goldsborough

fbshipit-source-id: 67dc345c1278a1a4ee4ca907d848bc1fb4956cfa
2018-11-27 17:35:53 -08:00

422 lines
15 KiB
Python
Executable File

import os
import shutil
import sys
import unittest

import torch
import torch.utils.cpp_extension
import torch.backends.cudnn

try:
    import torch_test_cpp_extension.cpp as cpp_extension
except ImportError:
    # The pre-built extension package is produced by run_test.py; running this
    # file directly skips that build step, so fail with a helpful hint.
    print("\'test_cpp_extensions.py\' cannot be invoked directly. " +
          "Run \'python run_test.py -i cpp_extensions\' for the \'test_cpp_extensions.py\' tests.")
    raise

import common_utils as common
from torch.utils.cpp_extension import CUDA_HOME
# CUDA tests need both a usable device and the CUDA toolkit (CUDA_HOME).
TEST_CUDA = torch.cuda.is_available() and CUDA_HOME is not None
TEST_CUDNN = False
if TEST_CUDA:
    # cuDNN tests additionally require the cudnn.h header under the toolkit
    # include directory and a loadable cuDNN library.
    CUDNN_HEADER_EXISTS = os.path.isfile(os.path.join(CUDA_HOME, 'include/cudnn.h'))
    TEST_CUDNN = TEST_CUDA and CUDNN_HEADER_EXISTS and torch.backends.cudnn.is_available()
IS_WINDOWS = sys.platform == 'win32'
class TestCppExtension(common.TestCase):
    """Tests for pre-built, JIT-compiled, and inline C++/CUDA extensions."""

    def setUp(self):
        # Start from a clean JIT build root so stale artifacts from an earlier
        # run cannot hide compilation problems.
        if not IS_WINDOWS:
            build_root = torch.utils.cpp_extension.get_default_build_root()
            if os.path.exists(build_root):
                shutil.rmtree(build_root)

    def test_extension_function(self):
        # Free function exported by the pre-built extension package.
        a = torch.randn(4, 4)
        b = torch.randn(4, 4)
        result = cpp_extension.sigmoid_add(a, b)
        self.assertEqual(result, a.sigmoid() + b.sigmoid())

    def test_extension_module(self):
        # Bound C++ class exported by the pre-built extension package.
        multiplier = cpp_extension.MatrixMultiplier(4, 8)
        weights = torch.rand(8, 4)
        expected = multiplier.get().mm(weights)
        self.assertEqual(expected, multiplier.forward(weights))

    def test_backward(self):
        # Gradients must flow both into the Python-side weights and into the
        # tensor owned by the C++ object.
        multiplier = cpp_extension.MatrixMultiplier(4, 8)
        weights = torch.rand(8, 4, requires_grad=True)
        multiplier.forward(weights).sum().backward()

        tensor = multiplier.get()
        self.assertEqual(weights.grad, tensor.t().mm(torch.ones([4, 4])))
        self.assertEqual(tensor.grad, torch.ones([4, 4]).mm(weights.t()))

    def test_jit_compile_extension(self):
        module = torch.utils.cpp_extension.load(
            name='jit_extension',
            sources=[
                'cpp_extensions/jit_extension.cpp',
                'cpp_extensions/jit_extension2.cpp'
            ],
            extra_include_paths=['cpp_extensions'],
            extra_cflags=['-g'],
            verbose=True)
        a = torch.randn(4, 4)
        b = torch.randn(4, 4)
        self.assertEqual(module.tanh_add(a, b), a.tanh() + b.tanh())

        # Checking we can call a method defined not in the main C++ file.
        self.assertEqual(module.exp_add(a, b), a.exp() + b.exp())

        # Checking we can use this JIT-compiled class.
        doubler = module.Doubler(2, 2)
        self.assertIsNone(doubler.get().grad)
        self.assertEqual(doubler.get().sum(), 4)
        self.assertEqual(doubler.forward().sum(), 8)

    @unittest.skipIf(not TEST_CUDA, "CUDA not found")
    def test_cuda_extension(self):
        import torch_test_cpp_extension.cuda as cuda_extension

        a = torch.zeros(100, device='cuda', dtype=torch.float32)
        b = torch.zeros(100, device='cuda', dtype=torch.float32)
        result = cuda_extension.sigmoid_add(a, b).cpu()
        # 2 * sigmoid(0) = 2 * 0.5 = 1
        self.assertEqual(result, torch.ones_like(result))

    @unittest.skipIf(not TEST_CUDA, "CUDA not found")
    def test_jit_cuda_extension(self):
        # NOTE: The name of the extension must equal the name of the module.
        module = torch.utils.cpp_extension.load(
            name='torch_test_cuda_extension',
            sources=[
                'cpp_extensions/cuda_extension.cpp',
                'cpp_extensions/cuda_extension.cu'
            ],
            extra_cuda_cflags=['-O2'],
            verbose=True)

        a = torch.zeros(100, device='cuda', dtype=torch.float32)
        b = torch.zeros(100, device='cuda', dtype=torch.float32)
        result = module.sigmoid_add(a, b).cpu()
        # 2 * sigmoid(0) = 2 * 0.5 = 1
        self.assertEqual(result, torch.ones_like(result))

    @unittest.skipIf(not TEST_CUDNN, "CuDNN not found")
    def test_jit_cudnn_extension(self):
        # implementation of CuDNN ReLU
        if IS_WINDOWS:
            extra_ldflags = ['cudnn.lib']
        else:
            extra_ldflags = ['-lcudnn']
        module = torch.utils.cpp_extension.load(
            name='torch_test_cudnn_extension',
            sources=[
                'cpp_extensions/cudnn_extension.cpp'
            ],
            extra_ldflags=extra_ldflags,
            verbose=True,
            with_cuda=True)

        x = torch.randn(100, device='cuda', dtype=torch.float32)
        y = torch.zeros(100, device='cuda', dtype=torch.float32)
        module.cudnn_relu(x, y)  # y=relu(x)
        self.assertEqual(torch.nn.functional.relu(x), y)
        # Mismatched output size must be rejected by the extension.
        with self.assertRaisesRegex(RuntimeError, "same size"):
            y_incorrect = torch.zeros(20, device='cuda', dtype=torch.float32)
            module.cudnn_relu(x, y_incorrect)

    def test_optional(self):
        # at::optional arguments should accept both a tensor and None.
        self.assertTrue(cpp_extension.function_taking_optional(torch.ones(5)))
        self.assertFalse(cpp_extension.function_taking_optional(None))

    def test_inline_jit_compile_extension_with_functions_as_list(self):
        cpp_source = '''
        torch::Tensor tanh_add(torch::Tensor x, torch::Tensor y) {
          return x.tanh() + y.tanh();
        }
        '''
        module = torch.utils.cpp_extension.load_inline(
            name='inline_jit_extension_with_functions_list',
            cpp_sources=cpp_source,
            functions='tanh_add',
            verbose=True)

        # When no docstring is given, the generated one repeats the name.
        self.assertEqual(module.tanh_add.__doc__.split('\n')[2], 'tanh_add')

        a = torch.randn(4, 4)
        b = torch.randn(4, 4)
        self.assertEqual(module.tanh_add(a, b), a.tanh() + b.tanh())

    def test_inline_jit_compile_extension_with_functions_as_dict(self):
        cpp_source = '''
        torch::Tensor tanh_add(torch::Tensor x, torch::Tensor y) {
          return x.tanh() + y.tanh();
        }
        '''
        module = torch.utils.cpp_extension.load_inline(
            name='inline_jit_extension_with_functions_dict',
            cpp_sources=cpp_source,
            functions={'tanh_add': 'Tanh and then sum :D'},
            verbose=True)

        # The dict value becomes the bound function's docstring.
        self.assertEqual(
            module.tanh_add.__doc__.split('\n')[2], 'Tanh and then sum :D')

    def test_inline_jit_compile_extension_multiple_sources_and_no_functions(self):
        cpp_source1 = '''
        torch::Tensor sin_add(torch::Tensor x, torch::Tensor y) {
          return x.sin() + y.sin();
        }
        '''

        # Second source supplies its own PYBIND11_MODULE, so no `functions`
        # argument is needed.
        cpp_source2 = '''
        #include <torch/extension.h>
        torch::Tensor sin_add(torch::Tensor x, torch::Tensor y);
        PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
          m.def("sin_add", &sin_add, "sin(x) + sin(y)");
        }
        '''

        module = torch.utils.cpp_extension.load_inline(
            name='inline_jit_extension',
            cpp_sources=[cpp_source1, cpp_source2],
            verbose=True)

        a = torch.randn(4, 4)
        b = torch.randn(4, 4)
        self.assertEqual(module.sin_add(a, b), a.sin() + b.sin())

    @unittest.skipIf(not TEST_CUDA, "CUDA not found")
    def test_inline_jit_compile_extension_cuda(self):
        cuda_source = '''
        __global__ void cos_add_kernel(
            const float* __restrict__ x,
            const float* __restrict__ y,
            float* __restrict__ output,
            const int size) {
          const auto index = blockIdx.x * blockDim.x + threadIdx.x;
          if (index < size) {
            output[index] = __cosf(x[index]) + __cosf(y[index]);
          }
        }

        torch::Tensor cos_add(torch::Tensor x, torch::Tensor y) {
          auto output = torch::zeros_like(x);
          const int threads = 1024;
          const int blocks = (output.numel() + threads - 1) / threads;
          cos_add_kernel<<<blocks, threads>>>(x.data<float>(), y.data<float>(), output.data<float>(), output.numel());
          return output;
        }
        '''

        # Here, the C++ source need only declare the function signature.
        cpp_source = 'torch::Tensor cos_add(torch::Tensor x, torch::Tensor y);'

        module = torch.utils.cpp_extension.load_inline(
            name='inline_jit_extension_cuda',
            cpp_sources=cpp_source,
            cuda_sources=cuda_source,
            functions=['cos_add'],
            verbose=True)

        self.assertEqual(module.cos_add.__doc__.split('\n')[2], 'cos_add')

        a = torch.randn(4, 4, device='cuda', dtype=torch.float32)
        b = torch.randn(4, 4, device='cuda', dtype=torch.float32)
        self.assertEqual(module.cos_add(a, b), a.cos() + b.cos())

    def test_inline_jit_compile_extension_throws_when_functions_is_bad(self):
        # `functions` must be a str, list, or dict — anything else is rejected.
        with self.assertRaises(ValueError):
            torch.utils.cpp_extension.load_inline(
                name='invalid_jit_extension', cpp_sources='', functions=5)

    def test_lenient_flag_handling_in_jit_extensions(self):
        cpp_source = '''
        torch::Tensor tanh_add(torch::Tensor x, torch::Tensor y) {
          return x.tanh() + y.tanh();
        }
        '''

        # Flags with stray whitespace/newlines or several options packed into
        # one string should still compile.
        module = torch.utils.cpp_extension.load_inline(
            name='lenient_flag_handling_extension',
            cpp_sources=cpp_source,
            functions='tanh_add',
            extra_cflags=['-g\n\n', '-O0 -Wall'],
            extra_include_paths=[' cpp_extensions\n'],
            verbose=True)

        a = torch.zeros(100, dtype=torch.float32)
        b = torch.zeros(100, dtype=torch.float32)
        self.assertEqual(module.tanh_add(a, b).cpu(), a.tanh() + b.tanh())

    def test_complex_registration(self):
        torch.utils.cpp_extension.load(
            name='complex_registration_extension',
            sources='cpp_extensions/complex_registration_extension.cpp',
            verbose=True)

        # The extension registers the complex dtype; constructing such a
        # tensor must now succeed (we only check that it does not throw).
        torch.empty(2, 2, dtype=torch.complex64)

    @unittest.skipIf(not TEST_CUDA, "CUDA not found")
    def test_half_support(self):
        '''
        Checks for an issue with operator< ambiguity for half when certain
        THC headers are included.

        See https://github.com/pytorch/pytorch/pull/10301#issuecomment-416773333
        for the corresponding issue.
        '''
        cuda_source = '''
        #include <THC/THCNumerics.cuh>

        template<typename T, typename U>
        __global__ void half_test_kernel(const T* input, U* output) {
          if (input[0] < input[1] || input[0] >= input[1]) {
            output[0] = 123;
          }
        }

        torch::Tensor half_test(torch::Tensor input) {
          auto output = torch::empty(1, input.options().dtype(torch::kFloat));
          AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "half_test", [&] {
            half_test_kernel<scalar_t><<<1, 1>>>(
              input.data<scalar_t>(),
              output.data<float>());
          });
          return output;
        }
        '''

        module = torch.utils.cpp_extension.load_inline(
            name='half_test_extension',
            cpp_sources='torch::Tensor half_test(torch::Tensor input);',
            cuda_sources=cuda_source,
            functions=['half_test'],
            verbose=True)

        x = torch.randn(3, device='cuda', dtype=torch.half)
        result = module.half_test(x)
        self.assertEqual(result[0], 123)

    def test_reload_jit_extension(self):
        # Recompiling under the same name must pick up new source code.
        def build(code):
            return torch.utils.cpp_extension.load_inline(
                name='reloaded_jit_extension',
                cpp_sources=code,
                functions='f',
                verbose=True)

        module = build('int f() { return 123; }')
        self.assertEqual(module.f(), 123)

        module = build('int f() { return 456; }')
        self.assertEqual(module.f(), 456)
        module = build('int f() { return 456; }')
        self.assertEqual(module.f(), 456)

        module = build('int f() { return 789; }')
        self.assertEqual(module.f(), 789)

    @unittest.skipIf(IS_WINDOWS, "C++ API not yet supported on Windows")
    def test_cpp_api_extension(self):
        here = os.path.abspath(__file__)
        pytorch_root = os.path.dirname(os.path.dirname(here))
        api_include = os.path.join(pytorch_root, 'torch', 'csrc', 'api', 'include')
        module = torch.utils.cpp_extension.load(
            name='cpp_api_extension',
            sources='cpp_extensions/cpp_api_extension.cpp',
            extra_include_paths=api_include,
            verbose=True)

        net = module.Net(3, 5)

        # train()/eval() must toggle the training flag.
        self.assertTrue(net.training)
        net.eval()
        self.assertFalse(net.training)
        net.train()
        self.assertTrue(net.training)
        net.eval()

        # In eval mode the forward pass is deterministic.
        input = torch.randn(2, 3, dtype=torch.float32)
        output = net.forward(input)
        self.assertEqual(output, net.forward(input))
        self.assertEqual(list(output.shape), [2, 5])

        # Mutating the bias must be observable both directly and through the
        # next forward pass.
        bias = net.get_bias()
        self.assertEqual(list(bias.shape), [5])
        net.set_bias(bias + 1)
        self.assertEqual(net.get_bias(), bias + 1)
        output2 = net.forward(input)
        self.assertNotEqual(output + 1, output2)

        # fc (weight + bias) and bn (weight + bias) give four parameters.
        self.assertEqual(len(net.parameters()), 4)

        p = net.named_parameters()
        self.assertEqual(len(p), 4)
        self.assertIn('fc.weight', p)
        self.assertIn('fc.bias', p)
        self.assertIn('bn.weight', p)
        self.assertIn('bn.bias', p)

    def test_returns_shared_library_path_when_is_python_module_is_true(self):
        source = '''
        #include <torch/script.h>
        torch::Tensor func(torch::Tensor x) { return x; }
        static torch::jit::RegisterOperators r("test::func", &func);
        '''
        # is_python_module=False loads the library for its side effects (op
        # registration) instead of importing it as a Python module.
        torch.utils.cpp_extension.load_inline(
            name="is_python_module",
            cpp_sources=source,
            functions="func",
            verbose=True,
            is_python_module=False)
        self.assertEqual(torch.ops.test.func(torch.eye(5)), torch.eye(5))

    @unittest.skipIf(IS_WINDOWS, "Not available on Windows")
    def test_no_python_abi_suffix_sets_the_correct_library_name(self):
        # For this test, run_test.py will call `python setup.py install` in the
        # cpp_extensions/no_python_abi_suffix_test folder, where the
        # `BuildExtension` class has a `no_python_abi_suffix` option set to
        # `True`. This *should* mean that on Python 3, the produced shared
        # library does not have an ABI suffix like
        # "cpython-37m-x86_64-linux-gnu" before the library suffix, e.g. "so".
        # On Python 2 there is no ABI suffix anyway.
        root = os.path.join("cpp_extensions", "no_python_abi_suffix_test", "build")
        print(list(os.walk(os.path.join("cpp_extensions", "no_python_abi_suffix_test"))))
        matches = [f for _, _, fs in os.walk(root) for f in fs if f.endswith("so")]
        self.assertEqual(len(matches), 1, str(matches))
        self.assertEqual(matches[0], "no_python_abi_suffix_test.so", str(matches))
# Allow running this file through the shared test harness entry point.
if __name__ == '__main__':
    common.run_tests()