mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[cpp extensions] Create torch.h and update setup.py
This commit is contained in:
parent
6665a45d5e
commit
1262fba8e7
1
.gitignore
vendored
1
.gitignore
vendored
|
|
@ -57,3 +57,4 @@ test/data/linear.pt
|
|||
.ninja_deps
|
||||
.ninja_log
|
||||
compile_commands.json
|
||||
*.egg-info/
|
||||
|
|
|
|||
60
setup.py
60
setup.py
|
|
@ -348,6 +348,19 @@ class install(setuptools.command.install.install):
|
|||
def run(self):
|
||||
if not self.skip_build:
|
||||
self.run_command('build_deps')
|
||||
|
||||
# Copy include directories necessary to compile C++ extensions.
|
||||
def copy_and_overwrite(src, dst):
|
||||
print('copying {} -> {}'.format(src, dst))
|
||||
if os.path.exists(dst):
|
||||
shutil.rmtree(dst)
|
||||
shutil.copytree(src, dst)
|
||||
|
||||
copy_and_overwrite('torch/csrc', 'torch/lib/include/torch/csrc/')
|
||||
copy_and_overwrite('torch/lib/pybind11/include/pybind11/',
|
||||
'torch/lib/include/pybind11')
|
||||
shutil.copy2('torch/torch.h', 'torch/lib/include/torch/torch.h')
|
||||
|
||||
setuptools.command.install.install.run(self)
|
||||
|
||||
|
||||
|
|
@ -736,20 +749,35 @@ cmdclass = {
|
|||
}
|
||||
cmdclass.update(build_dep_cmds)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
setup(name="torch", version=version,
|
||||
description="Tensors and Dynamic neural networks in Python with strong GPU acceleration",
|
||||
ext_modules=extensions,
|
||||
cmdclass=cmdclass,
|
||||
packages=packages,
|
||||
package_data={'torch': [
|
||||
'lib/*.so*', 'lib/*.dylib*', 'lib/*.dll', 'lib/*.lib',
|
||||
'lib/torch_shm_manager',
|
||||
'lib/*.h',
|
||||
'lib/include/TH/*.h', 'lib/include/TH/generic/*.h',
|
||||
'lib/include/THC/*.h', 'lib/include/THC/generic/*.h',
|
||||
'lib/include/ATen/*.h',
|
||||
]},
|
||||
install_requires=['pyyaml', 'numpy'],
|
||||
)
|
||||
setup(
|
||||
name="torch",
|
||||
version=version,
|
||||
description=("Tensors and Dynamic neural networks in "
|
||||
"Python with strong GPU acceleration"),
|
||||
ext_modules=extensions,
|
||||
cmdclass=cmdclass,
|
||||
packages=packages,
|
||||
package_data={
|
||||
'torch': [
|
||||
'lib/*.so*',
|
||||
'lib/*.dylib*',
|
||||
'lib/*.dll',
|
||||
'lib/*.lib',
|
||||
'lib/torch_shm_manager',
|
||||
'lib/*.h',
|
||||
'lib/include/ATen/*.h',
|
||||
'lib/include/pybind11/*.h',
|
||||
'lib/include/pybind11/detail/*.h',
|
||||
'lib/include/TH/*.h',
|
||||
'lib/include/TH/generic/*.h',
|
||||
'lib/include/THC/*.h',
|
||||
'lib/include/THC/generic/*.h',
|
||||
'lib/include/torch/csrc/*.h',
|
||||
'lib/include/torch/csrc/autograd/*.h',
|
||||
'lib/include/torch/csrc/jit/*.h',
|
||||
'lib/include/torch/csrc/utils/*.h',
|
||||
'lib/include/torch/torch.h',
|
||||
]
|
||||
},
|
||||
install_requires=['pyyaml', 'numpy'], )
|
||||
|
|
|
|||
30
test/cpp_extensions/extension.cpp
Normal file
30
test/cpp_extensions/extension.cpp
Normal file
|
|
@ -0,0 +1,30 @@
|
|||
#include <torch/torch.h>
|
||||
|
||||
using namespace at;
|
||||
|
||||
Tensor sigmoid_add(Tensor x, Tensor y) {
|
||||
return x.sigmoid() + y.sigmoid();
|
||||
}
|
||||
|
||||
struct MatrixMultiplier {
|
||||
MatrixMultiplier(int A, int B) {
|
||||
tensor_ = CPU(kDouble).ones({A, B});
|
||||
}
|
||||
Tensor forward(Tensor weights) {
|
||||
return tensor_.mm(weights);
|
||||
}
|
||||
Tensor get() const {
|
||||
return tensor_;
|
||||
}
|
||||
|
||||
private:
|
||||
Tensor tensor_;
|
||||
};
|
||||
|
||||
PYBIND11_MODULE(torch_test_cpp_extensions, m) {
|
||||
m.def("sigmoid_add", &sigmoid_add, "sigmoid(x) + sigmoid(y)");
|
||||
py::class_<MatrixMultiplier>(m, "MatrixMultiplier")
|
||||
.def(py::init<int, int>())
|
||||
.def("forward", &MatrixMultiplier::forward)
|
||||
.def("get", &MatrixMultiplier::get);
|
||||
}
|
||||
14
test/cpp_extensions/setup.py
Normal file
14
test/cpp_extensions/setup.py
Normal file
|
|
@ -0,0 +1,14 @@
|
|||
from setuptools import setup, Extension
|
||||
import torch.utils.cpp_extension
|
||||
|
||||
ext_modules = [
|
||||
Extension(
|
||||
'torch_test_cpp_extensions', ['extension.cpp'],
|
||||
include_dirs=torch.utils.cpp_extension.include_paths(),
|
||||
language='c++'),
|
||||
]
|
||||
|
||||
setup(
|
||||
name='torch_test_cpp_extensions',
|
||||
ext_modules=ext_modules,
|
||||
cmdclass={'build_ext': torch.utils.cpp_extension.BuildExtension})
|
||||
|
|
@ -64,6 +64,16 @@ $PYCMD test_cuda.py $@
|
|||
echo "Running NCCL tests"
|
||||
$PYCMD test_nccl.py $@
|
||||
|
||||
echo "Running C++ Extensions tests"
|
||||
cd cpp_extensions
|
||||
$PYCMD setup.py install --root ./install
|
||||
previous_pythonpath="$PYTHONPATH"
|
||||
export PYTHONPATH="$PWD/$(find ./install -name site-packages):$PYTHONPATH"
|
||||
cd ..
|
||||
$PYCMD test_cpp_extensions.py $@
|
||||
export PYTHONPATH="$previous_pythonpath"
|
||||
rm -rf cpp_extensions/install
|
||||
|
||||
# Skipping test_distributed for Windows because it doesn't have fcntl
|
||||
if [[ "$OSTYPE" != "msys" ]]; then
|
||||
distributed_set_up() {
|
||||
|
|
|
|||
23
test/test_cpp_extensions.py
Normal file
23
test/test_cpp_extensions.py
Normal file
|
|
@ -0,0 +1,23 @@
|
|||
import torch
|
||||
import torch_test_cpp_extensions as cpp_extension
|
||||
|
||||
import common
|
||||
|
||||
|
||||
class TestCppExtension(common.TestCase):
|
||||
def test_extension_function(self):
|
||||
x = torch.randn(4, 4)
|
||||
y = torch.randn(4, 4)
|
||||
z = cpp_extension.sigmoid_add(x, y)
|
||||
self.assertEqual(z, x.sigmoid() + y.sigmoid())
|
||||
|
||||
def test_extension_module(self):
|
||||
mm = cpp_extension.MatrixMultiplier(4, 8)
|
||||
weights = torch.rand(8, 4)
|
||||
expected = mm.get().mm(weights)
|
||||
result = mm.forward(weights)
|
||||
self.assertEqual(expected, result)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
common.run_tests()
|
||||
6
torch/torch.h
Normal file
6
torch/torch.h
Normal file
|
|
@ -0,0 +1,6 @@
|
|||
#pragma once
|
||||
|
||||
#include <Python.h>
|
||||
#include <ATen/ATen.h>
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <torch/csrc/utils/pybind.h>
|
||||
18
torch/utils/cpp_extension.py
Normal file
18
torch/utils/cpp_extension.py
Normal file
|
|
@ -0,0 +1,18 @@
|
|||
import os.path
|
||||
|
||||
from setuptools.command.build_ext import build_ext
|
||||
|
||||
|
||||
class BuildExtension(build_ext):
|
||||
"""A custom build extension for adding compiler-specific options."""
|
||||
|
||||
def build_extensions(self):
|
||||
for extension in self.extensions:
|
||||
extension.extra_compile_args = ['-std=c++11']
|
||||
build_ext.build_extensions(self)
|
||||
|
||||
|
||||
def include_paths():
|
||||
here = os.path.abspath(__file__)
|
||||
torch_path = os.path.dirname(os.path.dirname(here))
|
||||
return [os.path.join(torch_path, 'lib', 'include')]
|
||||
Loading…
Reference in New Issue
Block a user