Build mechanism for custom operators (#10226)
Summary:
This is the last step in the custom operator implementation: providing a way to build from C++ and Python. For this I:
1. Created a `FindTorch.cmake`, taken largely from ebetica, with a CMake function to easily create simple custom op libraries,
2. Created a `torch/op.h` header for easy inclusion of the necessary headers,
3. Created a test directory `pytorch/test/custom_operator` which includes the basic setup for a custom op:
   1. It defines an op in `op.{h,cpp}`,
   2. Registers it with the JIT using `RegisterOperators`,
   3. Builds it into a shared library via a `CMakeLists.txt`,
   4. Binds it into Python using a `setup.py`. This step makes use of the C++ extension setup that we already have. No work, yay!
The pure C++ and the Python builds are separate and not coupled in any way.
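As a concrete sketch of what this enables (illustrative only: the `example::scale` name and function are made up here, but the pattern matches `test/custom_operator/op.cpp` below):

```cpp
// A custom op library translation unit. Built into a shared library
// (e.g. with torch_add_custom_op_library), it can be loaded from Python
// via torch.ops.load_library('path/to/libexample.so') or linked from C++.
#include <torch/op.h>

// A schema (roughly "example::scale(Tensor tensor, float scalar) -> Tensor")
// is inferred from this signature when the op is registered.
at::Tensor scale(at::Tensor tensor, double scalar) {
  return tensor * scalar;
}

// Static registration: the constructor runs at library load time and
// registers the op with the JIT under the name "example::scale".
static torch::RegisterOperators registry("example::scale", &scale);
```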
zdevito soumith dzhulgakov
Pull Request resolved: https://github.com/pytorch/pytorch/pull/10226
Differential Revision: D9296839
Pulled By: goldsborough
fbshipit-source-id: 32f74cafb6e3d86cada8dfca8136d0dfb1f197a0
This commit is contained in:
parent 67c6d93634
commit c101a57a74
@@ -230,7 +230,7 @@ if(NOT MSVC)
    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Qunused-arguments")
  endif()
  if ((APPLE AND (NOT ("${CLANG_VERSION_STRING}" VERSION_LESS "9.0")))
      OR (CMAKE_COMPILER_IS_GNUCXX
      AND (CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 7.0 AND NOT APPLE)))
    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -faligned-new")
  endif()
cmake/TorchConfig.cmake.in (new file, 55 lines)
@@ -0,0 +1,55 @@
# FindTorch
# -------
#
# Finds the Torch library
#
# This will define the following variables:
#
#   TORCH_FOUND        -- True if the system has the Torch library
#   TORCH_INCLUDE_DIRS -- The include directories for torch
#   TORCH_LIBRARIES    -- Libraries to link to
#
# and the following imported targets:
#
#   Torch
#
# and the following functions:
#
#   torch_add_custom_op_library(<name> <source_files>)

SET(TORCH_ROOT "${CMAKE_CURRENT_LIST_DIR}/../")

set(TORCH_INCLUDE_DIRS
  "${TORCH_ROOT}"
  "${TORCH_ROOT}/aten/src"
  "${CMAKE_CURRENT_LIST_DIR}/aten/src"
  "${CMAKE_CURRENT_LIST_DIR}/caffe2/aten/src"
  "${CMAKE_CURRENT_LIST_DIR}/caffe2/aten/src/TH"
)

find_library(TORCH_LIBRARY torch PATHS "${CMAKE_CURRENT_LIST_DIR}/lib" NO_DEFAULT_PATH)
find_library(CAFFE2_LIBRARY caffe2 PATHS "${CMAKE_CURRENT_LIST_DIR}/lib" NO_DEFAULT_PATH)

if (@USE_CUDA@)
  find_package(CUDA REQUIRED)
  find_library(CAFFE2_CUDA_LIBRARY caffe2_gpu PATHS "${CMAKE_CURRENT_LIST_DIR}/lib" NO_DEFAULT_PATH)
  set(TORCH_CUDA_LIBRARIES -L${CUDA_TOOLKIT_ROOT_DIR}/lib64 cuda nvrtc cudart nvToolsExt)
  list(APPEND TORCH_INCLUDE_DIRS ${CUDA_TOOLKIT_INCLUDE})
endif()

set(TORCH_LIBRARIES
  ${TORCH_LIBRARY}
  ${CAFFE2_LIBRARY}
  ${CAFFE2_CUDA_LIBRARY}
  ${TORCH_CUDA_LIBRARIES})

# Creates a shared library <name> with the correct include directories
# and linker flags set to include Torch header files and link with Torch
# libraries. Also sets the C++ standard version to C++11. All options
# can be overridden by specifying further options on the `<name>` CMake target.
function(torch_add_custom_op_library name source_files)
  add_library(${name} SHARED ${source_files})
  target_include_directories(${name} PUBLIC "${TORCH_INCLUDE_DIRS}")
  target_link_libraries(${name} "${TORCH_LIBRARIES}")
  target_compile_options(${name} PUBLIC -std=c++11)
endfunction(torch_add_custom_op_library)
cmake/TorchConfigVersion.cmake.in (new file, 11 lines)
@@ -0,0 +1,11 @@
set(PACKAGE_VERSION "@TORCH_VERSION@")

# Check whether the requested PACKAGE_FIND_VERSION is compatible
if("${PACKAGE_VERSION}" VERSION_LESS "${PACKAGE_FIND_VERSION}")
  set(PACKAGE_VERSION_COMPATIBLE FALSE)
else()
  set(PACKAGE_VERSION_COMPATIBLE TRUE)
  if ("${PACKAGE_VERSION}" VERSION_EQUAL "${PACKAGE_FIND_VERSION}")
    set(PACKAGE_VERSION_EXACT TRUE)
  endif()
endif()
setup.py (1 line added)
@@ -415,6 +415,7 @@ class build_deps(PytorchCommand):
        self.copy_tree('third_party/pybind11/include/pybind11/',
                       'torch/lib/include/pybind11')
        self.copy_file('torch/csrc/torch.h', 'torch/lib/include/torch/torch.h')
+       self.copy_file('torch/op.h', 'torch/lib/include/torch/op.h')


build_dep_cmds = {}
test/custom_operator/CMakeLists.txt (new file, 10 lines)
@@ -0,0 +1,10 @@
# Basic CMake setup
cmake_minimum_required(VERSION 3.0 FATAL_ERROR)
project(custom_op)

find_package(Torch REQUIRED)

torch_add_custom_op_library(custom_op op.cpp)

add_executable(custom_op_test test.cpp)
target_link_libraries(custom_op_test custom_op)
test/custom_operator/op.cpp (new file, 18 lines)
@@ -0,0 +1,18 @@
#include <torch/op.h>

#include <cstddef>
#include <vector>

std::vector<at::Tensor> custom_op(
    at::Tensor tensor,
    double scalar,
    int64_t repeat) {
  std::vector<at::Tensor> output;
  output.reserve(repeat);
  for (int64_t i = 0; i < repeat; ++i) {
    output.push_back(tensor * scalar);
  }
  return output;
}

static torch::RegisterOperators registry("custom::op", &custom_op);
test/custom_operator/op.h (new file, 9 lines)
@@ -0,0 +1,9 @@
#include <torch/op.h>

#include <cstddef>
#include <vector>

std::vector<at::Tensor> custom_op(
    at::Tensor tensor,
    double scalar,
    int64_t repeat);
test/custom_operator/test.cpp (new file)
@@ -0,0 +1,25 @@
#include "op.h"

#include <cassert>
#include <iostream>
#include <vector>

int main() {
  auto& ops = torch::jit::getAllOperatorsFor(
      torch::jit::Symbol::fromQualString("custom::op"));
  assert(ops.size() == 1);

  auto& op = ops.front();
  assert(op->schema().name == "custom::op");

  torch::jit::Stack stack;
  torch::jit::push(stack, torch::ones(5), 2.0, 3);
  op->getOperation()(stack);
  std::vector<at::Tensor> output;
  torch::jit::pop(stack, output);

  assert(output.size() == 3);
  for (const auto& tensor : output) {
    assert(tensor.allclose(torch::ones(5) * 2));
  }
  std::cout << "success" << std::endl;
}
test/custom_operator/test.py (new file, 12 lines)
@@ -0,0 +1,12 @@
import os
import torch

library_path = os.path.abspath('build/libcustom_op.so')
torch.ops.load_library(library_path)
assert library_path in torch.ops.loaded_libraries

output = torch.ops.custom.op(torch.ones(5), 2.0, 3)
assert type(output) == list
assert len(output) == 3
assert all(tensor.allclose(torch.ones(5) * 2) for tensor in output)
print('success')
@@ -0,0 +1,4 @@ (new file)
graph(%x : Dynamic) {
  %1 : Dynamic = ^aten::relu()(%x)
  return (%1);
}
@@ -6266,7 +6266,7 @@ class TestJitGenerated(TestCase):
        pass


-class TestCustomOperators(TestCase):
+class TestCustomOperators(JitTestCase):

    def test_dynamic_op_registry(self):
        from torch._ops import _OpNamespace
@@ -6337,19 +6337,30 @@ class TestCustomOperators(TestCase):
            "Unknown keyword argument 'foo' for operator 'aten::leaky_relu'"
        ):
            torch.ops.aten.leaky_relu(torch.ones(5), foo=torch.ones(5))
-    #
-    # def test_passing_and_returning_lists(self):
-    #     a, b = torch.ones(5), torch.zeros(5)
-    #     output = torch.ops.aten.stack([a, b])
-    #     self.assertEqual(output, torch.ones(10))
-    #
-    # def test_throws_for_tuples(self):
-    #     with self.assertRaisesRegex(
-    #         RuntimeError,
-    #         "Unknown keyword argument 'foo' for operator 'aten::leaky_relu'"
-    #     ):
-    #         torch.ops.aten.leaky_relu(torch.ones(5), foo=torch.ones(5))

+    def test_passing_and_returning_lists(self):
+        # Replace with actual test once we support lists.
+        with self.assertRaisesRegex(
+            RuntimeError,
+            "Lists and tuples are not supported yet"
+        ):
+            a, b = torch.ones(5), torch.zeros(5)
+            output = torch.ops.aten.stack([a, b])
+            self.assertEqual(output, torch.ones(10))
+
+    def test_passing_and_returning_tuples(self):
+        # Replace with actual test once we support tuples.
+        with self.assertRaisesRegex(
+            RuntimeError,
+            "Lists and tuples are not supported yet"
+        ):
+            torch.ops.aten.max_pool2d(torch.ones(5, 5), [2, 2])
+
+    def test_script_graph_contains_custom_op(self):
+        @torch.jit.script
+        def func(x):
+            return torch.ops.aten.relu(x)
+        self.assertExpected(canonical(func.graph))

# UBSAN per-function exclusions don't seem to work with OpenMP pragmas,
# and we have to disable the failing tests here instead.
@@ -1,4 +1,5 @@
 #include "torch/csrc/jit/operator.h"
+#include "torch/csrc/jit/custom_operator.h"

 #include "torch/csrc/autograd/profiler.h"
 #include "torch/csrc/jit/interned_strings.h"
@@ -11,7 +11,14 @@ endif()

option(BUILD_TORCH_TEST "Build torch test binaries" ON)

+# TODO: Unify with version from setup.py
+set(TORCH_VERSION_MAJOR 0)
+set(TORCH_VERSION_MINOR 4)
+set(TORCH_VERSION_PATCH 1)
+set(TORCH_VERSION "${TORCH_VERSION_MAJOR}.${TORCH_VERSION_MINOR}.${TORCH_VERSION_PATCH}")
+
set(TORCH_SRC_DIR "${CMAKE_CURRENT_SOURCE_DIR}")
+set(TORCH_ROOT "${TORCH_SRC_DIR}/..")

add_subdirectory(../third_party/nanopb protobuf-nanopb)

@@ -55,9 +62,9 @@ else()
endif()

# Generate files
-set(TOOLS_PATH "${TORCH_SRC_DIR}/../tools")
+set(TOOLS_PATH "${TORCH_ROOT}/tools")

-configure_file("${TORCH_SRC_DIR}/../aten/src/ATen/common_with_cwrap.py"
+configure_file("${TORCH_ROOT}/aten/src/ATen/common_with_cwrap.py"
               "${TOOLS_PATH}/shared/cwrap_common.py"
               COPYONLY)

@@ -113,7 +120,7 @@ add_custom_command(
    "${TOOLS_PATH}/jit/gen_jit_dispatch.py"
    "${TOOLS_PATH}/jit/templates/register_aten_ops.cpp"
    "${TOOLS_PATH}/jit/templates/aten_interned_strings.h"
-  WORKING_DIRECTORY "${TORCH_SRC_DIR}/..")
+  WORKING_DIRECTORY "${TORCH_ROOT}")

set(TORCH_SRCS
  ${TORCH_SRC_DIR}/csrc/autograd/anomaly_mode.cpp

@@ -211,7 +218,7 @@ if(MSVC)
else()
  set (MSVC_RUNTIME_LIBRARY_FLAG "/MD")
endif()

target_compile_options(torch PRIVATE
  ${MSVC_RUNTIME_LIBRARY_OPTION}
  /Z7

@@ -339,9 +346,9 @@ endif()

set(TH_CPU_INCLUDE
  # dense
-  ${TORCH_SRC_DIR}/../aten/src/TH
+  ${TORCH_ROOT}/aten/src/TH
  ${CMAKE_CURRENT_BINARY_DIR}/../aten/src/TH
-  ${TORCH_SRC_DIR}/../aten/src
+  ${TORCH_ROOT}/aten/src
  ${CMAKE_CURRENT_BINARY_DIR}/../aten/src
  ${CMAKE_BINARY_DIR}/aten/src)
target_include_directories(torch PRIVATE ${TH_CPU_INCLUDE})

@@ -349,13 +356,13 @@ target_include_directories(torch PRIVATE ${TH_CPU_INCLUDE})
if(USE_CUDA OR USE_ROCM)
  set(TH_CUDA_INCLUDE
    # dense
-    ${TORCH_SRC_DIR}/../aten/src/THC
+    ${TORCH_ROOT}/aten/src/THC
    ${CMAKE_CURRENT_BINARY_DIR}/../aten/src/THC)
  target_include_directories(torch PRIVATE ${TH_CUDA_INCLUDE})
endif()

set(ATen_CPU_INCLUDE
-  ${TORCH_SRC_DIR}/../aten/src
+  ${TORCH_ROOT}/aten/src
  ${CMAKE_CURRENT_BINARY_DIR}/../aten/src
  ${CMAKE_CURRENT_BINARY_DIR}/../aten/src/ATen
  ${CMAKE_BINARY_DIR}/aten/src)

@@ -366,8 +373,8 @@ target_include_directories(torch PUBLIC

# SYSTEM headers are included with -isystem and thus do not trigger warnings.
target_include_directories(torch SYSTEM PUBLIC
-  "${TORCH_SRC_DIR}/../third_party/cereal/include" # For cereal/
-  "${TORCH_SRC_DIR}/../third_party/nanopb")
+  "${TORCH_ROOT}/third_party/cereal/include" # For cereal/
+  "${TORCH_ROOT}/third_party/nanopb")

set_target_properties(torch PROPERTIES VERSION 1 SOVERSION 1)

@@ -390,7 +397,7 @@ if (BUILD_TORCH_TEST AND NOT MSVC AND NOT APPLE AND NOT USE_ROCM)
  target_link_libraries(test_jit torch ${TORCH_CUDA_LIBRARIES})
  target_compile_definitions(test_jit PUBLIC USE_CATCH _FORCE_INLINES)
  target_include_directories(test_jit PUBLIC
-    "${TORCH_SRC_DIR}/../third_party/catch/single_include"
+    "${TORCH_ROOT}/third_party/catch/single_include"
    ${ATen_CPU_INCLUDE})

  if (USE_CUDA)

@@ -399,7 +406,7 @@ if (BUILD_TORCH_TEST AND NOT MSVC AND NOT APPLE AND NOT USE_ROCM)
endif()

if (BUILD_TORCH_TEST AND NOT NO_API AND NOT USE_ROCM)
-  set(TORCH_API_TEST_DIR "${TORCH_SRC_DIR}/../test/cpp/api")
+  set(TORCH_API_TEST_DIR "${TORCH_ROOT}/test/cpp/api")

  add_executable(test_api
    ${TORCH_API_TEST_DIR}/any.cpp

@@ -424,7 +431,7 @@ if (BUILD_TORCH_TEST AND NOT NO_API AND NOT USE_ROCM)

  target_include_directories(test_api
    PUBLIC
-    "${TORCH_SRC_DIR}/../third_party/catch/single_include"
+    "${TORCH_ROOT}/third_party/catch/single_include"
    ${ATen_CPU_INCLUDE})

  target_link_libraries(test_api torch ${TORCH_CUDA_LIBRARIES})

@@ -445,3 +452,13 @@ if (BUILD_TORCH_TEST AND NOT NO_API AND NOT USE_ROCM)
    endif()
  endif()
endif()
+
+# CMake config for external projects.
+configure_file(
+  ${PROJECT_SOURCE_DIR}/cmake/TorchConfigVersion.cmake.in
+  ${PROJECT_BINARY_DIR}/TorchConfigVersion.cmake
+  @ONLY)
+configure_file(
+  ${TORCH_ROOT}/cmake/TorchConfig.cmake.in
+  ${PROJECT_BINARY_DIR}/TorchConfig.cmake
+  @ONLY)
@@ -1,5 +1,27 @@
import torch._C

+import contextlib
+import ctypes
+import sys
+
+
+# Query `hasattr` only once.
+_SET_GLOBAL_FLAGS = hasattr(sys, 'getdlopenflags') and hasattr(sys, 'setdlopenflags')
+
+
+@contextlib.contextmanager
+def dl_open_guard():
+    """
+    Context manager to set the RTLD_GLOBAL dynamic linker flag while we open a
+    shared library to load custom operators.
+    """
+    if _SET_GLOBAL_FLAGS:
+        old_flags = sys.getdlopenflags()
+        sys.setdlopenflags(old_flags | ctypes.RTLD_GLOBAL)
+    yield
+    if _SET_GLOBAL_FLAGS:
+        sys.setdlopenflags(old_flags)
+
+
class _OpNamespace(object):
    """

@@ -33,12 +55,40 @@ class _OpNamespace(object):


class _Ops(object):
+    def __init__(self):
+        self.loaded_libraries = set()
+
    def __getattr__(self, name):
        # Here we are creating `torch.ops.my_namespace`
        namespace = _OpNamespace(name)
        setattr(self, name, namespace)
        return namespace

+    def load_library(self, path):
+        """
+        Loads a shared library from the given path into the current process.
+
+        The library being loaded may run global initialization code to register
+        custom operators with the PyTorch JIT runtime. This allows dynamically
+        loading custom operators. For this, you should compile your operator
+        and the static registration code into a shared library object, and then
+        call ``torch.ops.load_library('path/to/libcustom.so')`` to load the
+        shared object.
+
+        After the library is loaded, it is added to the
+        ``torch.ops.loaded_libraries`` attribute, a set that may be inspected
+        for the paths of all libraries loaded using this function.
+
+        Arguments:
+            path (str): A path to a shared library to load.
+        """
+        with dl_open_guard():
+            # Import the shared library into the process, thus running its
+            # static (global) initialization code in order to register custom
+            # operators with the JIT.
+            ctypes.CDLL(path)
+        self.loaded_libraries.add(path)
+

# The ops "namespace"
ops = _Ops()
@@ -1,5 +1,6 @@
 #include "torch/csrc/jit/constants.h"
 #include "torch/csrc/jit/operator.h"
+#include "torch/csrc/jit/custom_operator.h"
 #include "torch/csrc/autograd/variable.h"

 namespace torch { namespace jit {
@@ -1,7 +1,6 @@
 #pragma once

-#include <torch/csrc/jit/function_schema.h>
 #include <torch/csrc/jit/ir.h>
 #include <torch/csrc/jit/operator.h>
 #include <torch/csrc/jit/stack.h>
 #include <torch/csrc/jit/tracer.h>
@@ -79,13 +78,10 @@ Node* getTracedNode(
 /// Does two things for an operator implementation and a tuple of arguments:
 /// 1. Pops all necessary arguments off the stack into the tuple's elements,
 /// 2. Unpacks the tuple and calls the operator implementation.
-/// The result of the implementation call is returned.
-template <
-    typename ReturnType,
-    typename Implementation,
-    typename... Types,
-    size_t... Is>
-ReturnType callOperatorWithTuple(
+/// If tracing is currently enabled, this function will also take care of
+/// tracing the operator call.
+template <typename Implementation, typename... Types, size_t... Is>
+void callOperatorWithTuple(
     const FunctionSchema& schema,
     Implementation&& implementation,
     Stack& stack,
@@ -104,10 +100,10 @@
     jit::tracer::postRecordTrace(node, result);
   }

-  return result;
+  push(stack, IValue(std::move(result)));
 }

-void checkArgumentVector(
+inline void checkArgumentVector(
     const char* what,
     const std::vector<Argument>& inferred,
     const std::vector<Argument>& provided,
@@ -204,21 +200,54 @@ Operator createOperator(
       c10::guts::typelist::map_t<decay_t, typename Traits::parameter_types>;
   using ArgumentTuple =
       typename c10::guts::typelist::to_tuple<ArgumentTypes>::type;
-  using ReturnType = decay_t<typename Traits::return_type>;

   auto schema = torch::jit::detail::inferAndCheckSchema<Traits>(schemaOrName);

   return Operator(schema, [implementation, schema](Stack& stack) {
     ArgumentTuple tuple;
-    auto result = torch::jit::detail::callOperatorWithTuple<ReturnType>(
+    torch::jit::detail::callOperatorWithTuple(
         schema,
         std::move(implementation),
         stack,
         tuple,
         typename MakeIndices<std::tuple_size<ArgumentTuple>::value>::indices{});
-    pack(stack, std::move(result));
     return 0;
   });
 }

+/// Registration class for new operators. Effectively calls
+/// `torch::jit::registerOperator` for every supplied operator, but allows doing
+/// so in the global scope when a `RegisterOperators` object is assigned to a
+/// static variable. Also handles registration of user-defined, "custom"
+/// operators.
+struct TORCH_API RegisterOperators {
+  RegisterOperators() = default;
+
+  /// Registers a vector of already created `Operator`s.
+  RegisterOperators(std::vector<Operator> operators) {
+    for (Operator& o : operators) {
+      registerOperator(std::move(o));
+    }
+  }
+
+  /// Calls `op(...)` with the given operator name and implementation.
+  template <typename Implementation>
+  RegisterOperators(const std::string& name, Implementation&& implementation) {
+    op(name, std::forward<Implementation>(implementation));
+  }
+
+  /// Creates a new operator from a name and implementation function (function
+  /// pointer or function object/lambda) using `torch::jit::createOperator`, and
+  /// then registers the operator.
+  template <typename Implementation>
+  RegisterOperators& op(
+      const std::string& name,
+      Implementation&& implementation) {
+    registerOperator(
+        createOperator(name, std::forward<Implementation>(implementation)));
+    return *this;
+  }
+};
+
 } // namespace jit
 } // namespace torch
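Since `op()` returns `*this`, several operators can be registered through a single static object. A hypothetical usage sketch (the `demo::*` names and implementations are invented):

```cpp
#include <torch/csrc/jit/custom_operator.h>

static at::Tensor negate(at::Tensor tensor) {
  return -tensor;
}

// Chained registration: both plain functions and lambdas are accepted,
// and each schema is inferred from the callable's signature.
static torch::jit::RegisterOperators registry =
    torch::jit::RegisterOperators()
        .op("demo::negate", &negate)
        .op("demo::add_scalar",
            [](at::Tensor tensor, double scalar) { return tensor + scalar; });
```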
@@ -33,6 +33,13 @@

 #include <pybind11/functional.h>

+#include <memory>
+#include <sstream>
+#include <stdexcept>
+#include <string>
+#include <tuple>
+#include <utility>
+
 namespace torch { namespace jit {

 namespace {
@@ -232,10 +239,14 @@ void initJITBindings(PyObject *module) {
             "Found ", operations.size(), " overloads for operator ",
             qualified_name, "! Overloads are not supported from Python.");
         std::shared_ptr<Operator> op = operations[0];
+        AT_ASSERT(op != nullptr);
+        std::ostringstream docstring;
+        docstring << "Automatically bound operator '" << qualified_name
+                  << "' with schema: " << op->schema();
         return py::cpp_function([op](py::args args, py::kwargs kwargs) {
           return invokeOperatorFromPython(
               *op, std::move(args), std::move(kwargs));
-        });
+        }, py::name(qualified_name.c_str()), py::doc(docstring.str().c_str()));
       } catch (const at::Error& error) {
         throw std::runtime_error(error.what_without_backtrace());
       }
@@ -94,14 +94,6 @@ void registerOperator(Operator&& op);
 // XXX: this function is meant to be used with string literals only!
 Operator& sig(const char *signature_literal);

-struct TORCH_API RegisterOperators {
-  RegisterOperators(std::vector<Operator> operators) {
-    for(Operator& o : operators) {
-      registerOperator(std::move(o));
-    }
-  }
-};
-
 struct OperatorSet {
   OperatorSet(std::initializer_list<const char *> sig_literals);
   // XXX: Returns a nullptr if no Operator in the set matches n
@@ -7,6 +7,7 @@
 #include "torch/csrc/autograd/variable.h"
 #include "torch/csrc/jit/fusion_compiler.h"
 #include "torch/csrc/jit/operator.h"
+#include "torch/csrc/jit/custom_operator.h"
 #include "torch/csrc/jit/graph_executor.h"
 #include "torch/csrc/jit/ir.h"
@@ -7,6 +7,7 @@
 #include "torch/csrc/jit/graph_executor.h"
 #include "torch/csrc/jit/ir.h"
 #include "torch/csrc/jit/operator.h"
+#include "torch/csrc/jit/custom_operator.h"

 #include "torch/csrc/variable_tensor_functions.h"
@@ -1058,6 +1058,34 @@ void testCustomOperators() {
     REQUIRE(output[0] == 1.0);
     REQUIRE(output[1] == 2.0);
   }
+  {
+    RegisterOperators reg(
+        "foo::lists2(Tensor[] tensors) -> Tensor[]",
+        [](std::vector<at::Tensor> tensors) { return tensors; });
+
+    auto& ops =
+        getAllOperatorsFor(Symbol::fromQualString("foo::lists2"));
+    REQUIRE(ops.size() == 1);
+
+    auto& op = ops.front();
+    REQUIRE(op->schema().name == "foo::lists2");
+
+    REQUIRE(op->schema().arguments.size() == 1);
+    REQUIRE(op->schema().arguments[0].name == "tensors");
+    REQUIRE(op->schema().arguments[0].type->isSubtypeOf(ListType::ofTensors()));
+
+    REQUIRE(op->schema().returns.size() == 1);
+    REQUIRE(op->schema().returns[0].type->isSubtypeOf(ListType::ofTensors()));
+
+    Stack stack;
+    push(stack, std::vector<at::Tensor>{autograd::make_variable(at::ones(5))});
+    op->getOperation()(stack);
+    std::vector<at::Tensor> output;
+    pop(stack, output);
+
+    REQUIRE(output.size() == 1);
+    REQUIRE(output[0].allclose(autograd::make_variable(at::ones(5))));
+  }
   {
 #ifdef USE_CATCH
     REQUIRE_THROWS_WITH(
torch/op.h (new file, 11 lines)
@@ -0,0 +1,11 @@
#pragma once

#include <torch/csrc/autograd/generated/variable_factories.h>
#include <torch/csrc/jit/custom_operator.h>

#include <ATen/ATen.h>

namespace torch {
using jit::createOperator;
using jit::RegisterOperators;
} // namespace torch