This PR adds support for convenient CUDA integration in our C++ extension mechanism. The main work was getting setuptools to use nvcc for CUDA files and the regular C++ compiler for C++ files. I've added a mixed C++/CUDA test case, which works well. I've also added CUDAExtension and CppExtension functions that construct a setuptools.Extension with "usually the right" arguments, which further reduces the boilerplate needed to write an extension. This helps especially for CUDA, where library_dirs (CUDA_HOME/lib64) and libraries (cudart) have to be specified as well. The next step is to enable this in our "JIT" mechanism.

NOTE: I had to write a small find_cuda_home function to locate the CUDA install directory. This logic partly duplicates tools/setup_helpers/cuda.py, but that module isn't available in the shipped PyTorch distribution, and the function is fairly short. Let me know if it's fine to duplicate this logic.

Commits:

* CUDA support for C++ extensions with setuptools
* Remove printf in CUDA test kernel
* Remove -arch flag in test/cpp_extensions/setup.py
* Put wrap_compile into BuildExtension
* Add guesses for CUDA_HOME directory
* export PATH to CUDA location in test.sh
* On Python2, sys.platform has the linux version number
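For reference, a minimal sketch of what a setup.py using these helpers might look like; the extension name and source file names here are illustrative, not taken from this PR:

    from setuptools import setup
    from torch.utils.cpp_extension import BuildExtension, CUDAExtension

    setup(
        name='my_cuda_extension',
        ext_modules=[
            # CUDAExtension fills in the CUDA include/library paths
            # (e.g. CUDA_HOME/lib64) and links against cudart, so they
            # don't have to be spelled out by hand.
            CUDAExtension('my_cuda_extension', [
                'my_extension.cpp',
                'my_extension_kernel.cu',
            ]),
        ],
        # BuildExtension dispatches .cu files to nvcc and the rest to
        # the regular C++ compiler during build_ext.
        cmdclass={'build_ext': BuildExtension})

And a hedged sketch of the shape of the find_cuda_home fallback logic; the actual guesses in this PR may differ:

    import os
    import os.path

    def find_cuda_home():
        # Sketch only: prefer the CUDA_HOME environment variable, then
        # fall back to the conventional install location.
        cuda_home = os.environ.get('CUDA_HOME')
        if cuda_home is None and os.path.exists('/usr/local/cuda'):
            cuda_home = '/usr/local/cuda'
        return cuda_home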
56 lines · 1.6 KiB · Python
import unittest

import torch
import torch.utils.cpp_extension
import torch_test_cpp_extensions as cpp_extension

import common

TEST_CUDA = torch.cuda.is_available()


class TestCppExtension(common.TestCase):
    def test_extension_function(self):
        x = torch.randn(4, 4)
        y = torch.randn(4, 4)
        z = cpp_extension.sigmoid_add(x, y)
        self.assertEqual(z, x.sigmoid() + y.sigmoid())

    def test_extension_module(self):
        mm = cpp_extension.MatrixMultiplier(4, 8)
        weights = torch.rand(8, 4)
        expected = mm.get().mm(weights)
        result = mm.forward(weights)
        self.assertEqual(expected, result)

    def test_jit_compile_extension(self):
        module = torch.utils.cpp_extension.load(
            name='jit_extension',
            sources=[
                'cpp_extensions/jit_extension.cpp',
                'cpp_extensions/jit_extension2.cpp'
            ],
            extra_include_paths=['cpp_extensions'],
            extra_cflags=['-g'],
            verbose=True)
        x = torch.randn(4, 4)
        y = torch.randn(4, 4)
        z = module.tanh_add(x, y)
        self.assertEqual(z, x.tanh() + y.tanh())

    @unittest.skipIf(not TEST_CUDA, "CUDA not found")
    def test_cuda_extension(self):
        import torch_test_cuda_extension as cuda_extension

        x = torch.FloatTensor(100).zero_().cuda()
        y = torch.FloatTensor(100).zero_().cuda()

        z = cuda_extension.sigmoid_add(x, y).cpu()

        # 2 * sigmoid(0) = 2 * 0.5 = 1
        self.assertEqual(z, torch.ones_like(z))


if __name__ == '__main__':
    common.run_tests()