Build mechanism for custom operators (#10226)

Summary:
This is the last step in the custom operator implementation: providing a way to build from C++ and Python. For this I:

1. Created a `FindTorch.cmake` taken largely from ebetica with a CMake function to easily create simple custom op libraries
2. Created a ` torch/op.h` header for easy inclusion of necessary headers,
3. Created a test directory `pytorch/test/custom_operator` which includes the basic setup for a custom op.
    1. It defines an op in `op.{h,cpp}`
    2. Registers it with the JIT using `RegisterOperators`
    3. Builds it into a shared library via a `CMakeLists.txt`
    4. Binds it into Python using a `setup.py`. This step makes use of our C++ extension setup that we already have. No work, yey!

The pure C++ and the Python builds are separate and not coupled in any way.

zdevito soumith dzhulgakov
Pull Request resolved: https://github.com/pytorch/pytorch/pull/10226

Differential Revision: D9296839

Pulled By: goldsborough

fbshipit-source-id: 32f74cafb6e3d86cada8dfca8136d0dfb1f197a0
This commit is contained in:
Peter Goldsborough 2018-08-16 18:45:28 -07:00 committed by Facebook Github Bot
parent 67c6d93634
commit c101a57a74
22 changed files with 347 additions and 49 deletions

View File

@ -230,7 +230,7 @@ if(NOT MSVC)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Qunused-arguments") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Qunused-arguments")
endif() endif()
if ((APPLE AND (NOT ("${CLANG_VERSION_STRING}" VERSION_LESS "9.0"))) if ((APPLE AND (NOT ("${CLANG_VERSION_STRING}" VERSION_LESS "9.0")))
OR (CMAKE_COMPILER_IS_GNUCXX OR (CMAKE_COMPILER_IS_GNUCXX
AND (CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 7.0 AND NOT APPLE))) AND (CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 7.0 AND NOT APPLE)))
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -faligned-new") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -faligned-new")
endif() endif()

View File

@ -0,0 +1,55 @@
# FindTorch
# -------
#
# Finds the Torch library
#
# This will define the following variables:
#
# TORCH_FOUND -- True if the system has the Torch library
# TORCH_INCLUDE_DIRS -- The include directories for torch
# TORCH_LIBRARIES -- Libraries to link to
#
# and the following imported targets:
#
# Torch
#
# and the following functions:
#
# torch_add_custom_op_library(<name> <source_files>)
SET(TORCH_ROOT "${CMAKE_CURRENT_LIST_DIR}/../")
set(TORCH_INCLUDE_DIRS
"${TORCH_ROOT}"
"${TORCH_ROOT}/aten/src"
"${CMAKE_CURRENT_LIST_DIR}/aten/src"
"${CMAKE_CURRENT_LIST_DIR}/caffe2/aten/src"
"${CMAKE_CURRENT_LIST_DIR}/caffe2/aten/src/TH"
)
find_library(TORCH_LIBRARY torch PATHS "${CMAKE_CURRENT_LIST_DIR}/lib" NO_DEFAULT_PATH)
find_library(CAFFE2_LIBRARY caffe2 PATHS "${CMAKE_CURRENT_LIST_DIR}/lib" NO_DEFAULT_PATH)
if (@USE_CUDA@)
find_package(CUDA REQUIRED)
find_library(CAFFE2_CUDA_LIBRARY caffe2_gpu PATHS "${CMAKE_CURRENT_LIST_DIR}/lib" NO_DEFAULT_PATH)
set(TORCH_CUDA_LIBRARIES -L${CUDA_TOOLKIT_ROOT_DIR}/lib64 cuda nvrtc cudart nvToolsExt)
list(APPEND TORCH_INCLUDE_DIRS ${CUDA_TOOLKIT_INCLUDE})
endif()
set(TORCH_LIBRARIES
${TORCH_LIBRARY}
${CAFFE2_LIBRARY}
${CAFFE2_CUDA_LIBRARY}
${TORCH_CUDA_LIBRARIES})
# Creates a shared library <name> with the correct include directories
# and linker flags set to include Torch header files and link with Torch
# libraries. Also sets the C++ standard version to C++11. All options
# can be override by specifying further options on the `<name>` CMake target.
function(torch_add_custom_op_library name source_files)
add_library(${name} SHARED ${source_files})
target_include_directories(${name} PUBLIC "${TORCH_INCLUDE_DIRS}")
target_link_libraries(${name} "${TORCH_LIBRARIES}")
target_compile_options(${name} PUBLIC -std=c++11)
endfunction(torch_add_custom_op_library)

View File

@ -0,0 +1,11 @@
set(PACKAGE_VERSION "@TORCH_VERSION@")
# Check whether the requested PACKAGE_FIND_VERSION is compatible
if("${PACKAGE_VERSION}" VERSION_LESS "${PACKAGE_FIND_VERSION}")
set(PACKAGE_VERSION_COMPATIBLE FALSE)
else()
set(PACKAGE_VERSION_COMPATIBLE TRUE)
if ("${PACKAGE_VERSION}" VERSION_EQUAL "${PACKAGE_FIND_VERSION}")
set(PACKAGE_VERSION_EXACT TRUE)
endif()
endif()

View File

@ -415,6 +415,7 @@ class build_deps(PytorchCommand):
self.copy_tree('third_party/pybind11/include/pybind11/', self.copy_tree('third_party/pybind11/include/pybind11/',
'torch/lib/include/pybind11') 'torch/lib/include/pybind11')
self.copy_file('torch/csrc/torch.h', 'torch/lib/include/torch/torch.h') self.copy_file('torch/csrc/torch.h', 'torch/lib/include/torch/torch.h')
self.copy_file('torch/op.h', 'torch/lib/include/torch/op.h')
build_dep_cmds = {} build_dep_cmds = {}

View File

@ -0,0 +1,10 @@
# Basic CMake setup
cmake_minimum_required(VERSION 3.0 FATAL_ERROR)
project(custom_op)
find_package(Torch REQUIRED)
torch_add_custom_op_library(custom_op op.cpp)
add_executable(custom_op_test test.cpp)
target_link_libraries(custom_op_test custom_op)

View File

@ -0,0 +1,18 @@
#include <torch/op.h>
#include <cstddef>
#include <vector>
std::vector<at::Tensor> custom_op(
at::Tensor tensor,
double scalar,
int64_t repeat) {
std::vector<at::Tensor> output;
output.reserve(repeat);
for (int64_t i = 0; i < repeat; ++i) {
output.push_back(tensor * scalar);
}
return output;
}
static torch::RegisterOperators registry("custom::op", &custom_op);

View File

@ -0,0 +1,9 @@
#include <torch/op.h>
#include <cstddef>
#include <vector>
std::vector<at::Tensor> custom_op(
at::Tensor tensor,
double scalar,
int64_t repeat);

View File

@ -0,0 +1,25 @@
#include "op.h"
#include <cassert>
#include <vector>
int main() {
auto& ops = torch::jit::getAllOperatorsFor(
torch::jit::Symbol::fromQualString("custom::op"));
assert(ops.size() == 1);
auto& op = ops.front();
assert(op->schema().name == "custom::op");
torch::jit::Stack stack;
torch::jit::push(stack, torch::ones(5), 2.0, 3);
op->getOperation()(stack);
std::vector<at::Tensor> output;
torch::jit::pop(stack, output);
assert(output.size() == 3);
for (const auto& tensor : output) {
assert(tensor.allclose(torch::ones(5) * 2));
}
std::cout << "success" << std::endl;
}

View File

@ -0,0 +1,12 @@
import os
import torch
library_path = os.path.abspath('build/libcustom_op.so')
torch.ops.load_library(library_path)
assert library_path in torch.ops.loaded_libraries
output = torch.ops.custom.op(torch.ones(5), 2.0, 3)
assert type(output) == list
assert len(output) == 3
assert all(tensor.allclose(torch.ones(5) * 2) for tensor in output)
print('success')

View File

@ -0,0 +1,4 @@
graph(%x : Dynamic) {
%1 : Dynamic = ^aten::relu()(%x)
return (%1);
}

View File

@ -6266,7 +6266,7 @@ class TestJitGenerated(TestCase):
pass pass
class TestCustomOperators(TestCase): class TestCustomOperators(JitTestCase):
def test_dynamic_op_registry(self): def test_dynamic_op_registry(self):
from torch._ops import _OpNamespace from torch._ops import _OpNamespace
@ -6337,19 +6337,30 @@ class TestCustomOperators(TestCase):
"Unknown keyword argument 'foo' for operator 'aten::leaky_relu'" "Unknown keyword argument 'foo' for operator 'aten::leaky_relu'"
): ):
torch.ops.aten.leaky_relu(torch.ones(5), foo=torch.ones(5)) torch.ops.aten.leaky_relu(torch.ones(5), foo=torch.ones(5))
#
# def test_passing_and_returning_lists(self):
# a, b = torch.ones(5), torch.zeros(5)
# output = torch.ops.aten.stack([a, b])
# self.assertEqual(output, torch.ones(10))
#
# def test_throws_for_tuples(self):
# with self.assertRaisesRegex(
# RuntimeError,
# "Unknown keyword argument 'foo' for operator 'aten::leaky_relu'"
# ):
# torch.ops.aten.leaky_relu(torch.ones(5), foo=torch.ones(5))
def test_passing_and_returning_lists(self):
# Replace with actual test once we support lists.
with self.assertRaisesRegex(
RuntimeError,
"Lists and tuples are not supported yet"
):
a, b = torch.ones(5), torch.zeros(5)
output = torch.ops.aten.stack([a, b])
self.assertEqual(output, torch.ones(10))
def test_passing_and_returning_tuples(self):
# Replace with actual test once we support tuples.
with self.assertRaisesRegex(
RuntimeError,
"Lists and tuples are not supported yet"
):
torch.ops.aten.max_pool2d(torch.ones(5, 5), [2, 2])
def test_script_graph_contains_custom_op(self):
@torch.jit.script
def func(x):
return torch.ops.aten.relu(x)
self.assertExpected(canonical(func.graph))
# UBSAN per-function exclusions don't seem to work with OpenMP pragmas, # UBSAN per-function exclusions don't seem to work with OpenMP pragmas,
# and we have to disable the failing tests here instead. # and we have to disable the failing tests here instead.

View File

@ -1,4 +1,5 @@
#include "torch/csrc/jit/operator.h" #include "torch/csrc/jit/operator.h"
#include "torch/csrc/jit/custom_operator.h"
#include "torch/csrc/autograd/profiler.h" #include "torch/csrc/autograd/profiler.h"
#include "torch/csrc/jit/interned_strings.h" #include "torch/csrc/jit/interned_strings.h"

View File

@ -11,7 +11,14 @@ endif()
option(BUILD_TORCH_TEST "Build torch test binaries" ON) option(BUILD_TORCH_TEST "Build torch test binaries" ON)
# TODO: Unify with version from setup.py
set(TORCH_VERSION_MAJOR 0)
set(TORCH_VERSION_MINOR 4)
set(TORCH_VERSION_PATCH 1)
set(TORCH_VERSION "${TORCH_VERSION_MAJOR}.${TORCH_VERSION_MINOR}.${TORCH_VERSION_PATCH}")
set(TORCH_SRC_DIR "${CMAKE_CURRENT_SOURCE_DIR}") set(TORCH_SRC_DIR "${CMAKE_CURRENT_SOURCE_DIR}")
set(TORCH_ROOT "${TORCH_SRC_DIR}/..")
add_subdirectory(../third_party/nanopb protobuf-nanopb) add_subdirectory(../third_party/nanopb protobuf-nanopb)
@ -55,9 +62,9 @@ else()
endif() endif()
# Generate files # Generate files
set(TOOLS_PATH "${TORCH_SRC_DIR}/../tools") set(TOOLS_PATH "${TORCH_ROOT}/tools")
configure_file("${TORCH_SRC_DIR}/../aten/src/ATen/common_with_cwrap.py" configure_file("${TORCH_ROOT}/aten/src/ATen/common_with_cwrap.py"
"${TOOLS_PATH}/shared/cwrap_common.py" "${TOOLS_PATH}/shared/cwrap_common.py"
COPYONLY) COPYONLY)
@ -113,7 +120,7 @@ add_custom_command(
"${TOOLS_PATH}/jit/gen_jit_dispatch.py" "${TOOLS_PATH}/jit/gen_jit_dispatch.py"
"${TOOLS_PATH}/jit/templates/register_aten_ops.cpp" "${TOOLS_PATH}/jit/templates/register_aten_ops.cpp"
"${TOOLS_PATH}/jit/templates/aten_interned_strings.h" "${TOOLS_PATH}/jit/templates/aten_interned_strings.h"
WORKING_DIRECTORY "${TORCH_SRC_DIR}/..") WORKING_DIRECTORY "${TORCH_ROOT}")
set(TORCH_SRCS set(TORCH_SRCS
${TORCH_SRC_DIR}/csrc/autograd/anomaly_mode.cpp ${TORCH_SRC_DIR}/csrc/autograd/anomaly_mode.cpp
@ -211,7 +218,7 @@ if(MSVC)
else() else()
set (MSVC_RUNTIME_LIBRARY_FLAG "/MD") set (MSVC_RUNTIME_LIBRARY_FLAG "/MD")
endif() endif()
target_compile_options(torch PRIVATE target_compile_options(torch PRIVATE
${MSVC_RUNTIME_LIBRARY_OPTION} ${MSVC_RUNTIME_LIBRARY_OPTION}
/Z7 /Z7
@ -339,9 +346,9 @@ endif()
set(TH_CPU_INCLUDE set(TH_CPU_INCLUDE
# dense # dense
${TORCH_SRC_DIR}/../aten/src/TH ${TORCH_ROOT}/aten/src/TH
${CMAKE_CURRENT_BINARY_DIR}/../aten/src/TH ${CMAKE_CURRENT_BINARY_DIR}/../aten/src/TH
${TORCH_SRC_DIR}/../aten/src ${TORCH_ROOT}/aten/src
${CMAKE_CURRENT_BINARY_DIR}/../aten/src ${CMAKE_CURRENT_BINARY_DIR}/../aten/src
${CMAKE_BINARY_DIR}/aten/src) ${CMAKE_BINARY_DIR}/aten/src)
target_include_directories(torch PRIVATE ${TH_CPU_INCLUDE}) target_include_directories(torch PRIVATE ${TH_CPU_INCLUDE})
@ -349,13 +356,13 @@ target_include_directories(torch PRIVATE ${TH_CPU_INCLUDE})
if(USE_CUDA OR USE_ROCM) if(USE_CUDA OR USE_ROCM)
set(TH_CUDA_INCLUDE set(TH_CUDA_INCLUDE
# dense # dense
${TORCH_SRC_DIR}/../aten/src/THC ${TORCH_ROOT}/aten/src/THC
${CMAKE_CURRENT_BINARY_DIR}/../aten/src/THC) ${CMAKE_CURRENT_BINARY_DIR}/../aten/src/THC)
target_include_directories(torch PRIVATE ${TH_CUDA_INCLUDE}) target_include_directories(torch PRIVATE ${TH_CUDA_INCLUDE})
endif() endif()
set(ATen_CPU_INCLUDE set(ATen_CPU_INCLUDE
${TORCH_SRC_DIR}/../aten/src ${TORCH_ROOT}/aten/src
${CMAKE_CURRENT_BINARY_DIR}/../aten/src ${CMAKE_CURRENT_BINARY_DIR}/../aten/src
${CMAKE_CURRENT_BINARY_DIR}/../aten/src/ATen ${CMAKE_CURRENT_BINARY_DIR}/../aten/src/ATen
${CMAKE_BINARY_DIR}/aten/src) ${CMAKE_BINARY_DIR}/aten/src)
@ -366,8 +373,8 @@ target_include_directories(torch PUBLIC
# SYSTEM headers are included with -isystem and thus do not trigger warnings. # SYSTEM headers are included with -isystem and thus do not trigger warnings.
target_include_directories(torch SYSTEM PUBLIC target_include_directories(torch SYSTEM PUBLIC
"${TORCH_SRC_DIR}/../third_party/cereal/include" # For cereal/ "${TORCH_ROOT}/third_party/cereal/include" # For cereal/
"${TORCH_SRC_DIR}/../third_party/nanopb") "${TORCH_ROOT}/third_party/nanopb")
set_target_properties(torch PROPERTIES VERSION 1 SOVERSION 1) set_target_properties(torch PROPERTIES VERSION 1 SOVERSION 1)
@ -390,7 +397,7 @@ if (BUILD_TORCH_TEST AND NOT MSVC AND NOT APPLE AND NOT USE_ROCM)
target_link_libraries(test_jit torch ${TORCH_CUDA_LIBRARIES}) target_link_libraries(test_jit torch ${TORCH_CUDA_LIBRARIES})
target_compile_definitions(test_jit PUBLIC USE_CATCH _FORCE_INLINES) target_compile_definitions(test_jit PUBLIC USE_CATCH _FORCE_INLINES)
target_include_directories(test_jit PUBLIC target_include_directories(test_jit PUBLIC
"${TORCH_SRC_DIR}/../third_party/catch/single_include" "${TORCH_ROOT}/third_party/catch/single_include"
${ATen_CPU_INCLUDE}) ${ATen_CPU_INCLUDE})
if (USE_CUDA) if (USE_CUDA)
@ -399,7 +406,7 @@ if (BUILD_TORCH_TEST AND NOT MSVC AND NOT APPLE AND NOT USE_ROCM)
endif() endif()
if (BUILD_TORCH_TEST AND NOT NO_API AND NOT USE_ROCM) if (BUILD_TORCH_TEST AND NOT NO_API AND NOT USE_ROCM)
set(TORCH_API_TEST_DIR "${TORCH_SRC_DIR}/../test/cpp/api") set(TORCH_API_TEST_DIR "${TORCH_ROOT}/test/cpp/api")
add_executable(test_api add_executable(test_api
${TORCH_API_TEST_DIR}/any.cpp ${TORCH_API_TEST_DIR}/any.cpp
@ -424,7 +431,7 @@ if (BUILD_TORCH_TEST AND NOT NO_API AND NOT USE_ROCM)
target_include_directories(test_api target_include_directories(test_api
PUBLIC PUBLIC
"${TORCH_SRC_DIR}/../third_party/catch/single_include" "${TORCH_ROOT}/third_party/catch/single_include"
${ATen_CPU_INCLUDE}) ${ATen_CPU_INCLUDE})
target_link_libraries(test_api torch ${TORCH_CUDA_LIBRARIES}) target_link_libraries(test_api torch ${TORCH_CUDA_LIBRARIES})
@ -445,3 +452,13 @@ if (BUILD_TORCH_TEST AND NOT NO_API AND NOT USE_ROCM)
endif() endif()
endif() endif()
endif() endif()
# CMake config for external projects.
configure_file(
${PROJECT_SOURCE_DIR}/cmake/TorchConfigVersion.cmake.in
${PROJECT_BINARY_DIR}/TorchConfigVersion.cmake
@ONLY)
configure_file(
${TORCH_ROOT}/cmake/TorchConfig.cmake.in
${PROJECT_BINARY_DIR}/TorchConfig.cmake
@ONLY)

View File

@ -1,5 +1,27 @@
import torch._C import torch._C
import contextlib
import ctypes
import sys
# Query `hasattr` only once.
_SET_GLOBAL_FLAGS = hasattr(sys, 'getdlopenflags') and hasattr(sys, 'setdlopenflags')
@contextlib.contextmanager
def dl_open_guard():
"""
Context manager to set the RTLD_GLOBAL dynamic linker flag while we open a
shared library to load custom operators.
"""
if _SET_GLOBAL_FLAGS:
old_flags = sys.getdlopenflags()
sys.setdlopenflags(old_flags | ctypes.RTLD_GLOBAL)
yield
if _SET_GLOBAL_FLAGS:
sys.setdlopenflags(old_flags)
class _OpNamespace(object): class _OpNamespace(object):
""" """
@ -33,12 +55,40 @@ class _OpNamespace(object):
class _Ops(object): class _Ops(object):
def __init__(self):
self.loaded_libraries = set()
def __getattr__(self, name): def __getattr__(self, name):
# Here we are creating `torch.ops.my_namespace` # Here we are creating `torch.ops.my_namespace`
namespace = _OpNamespace(name) namespace = _OpNamespace(name)
setattr(self, name, namespace) setattr(self, name, namespace)
return namespace return namespace
def load_library(self, path):
"""
Loads a shared library from the given path into the current process.
The library being loaded may run global initialization code to register
custom operators with the PyTorch JIT runtime. This allows dynamically
loading custom operators. For this, you should compile your operator
and the static registration code into a shared library object, and then
call ``torch.ops.load_library('path/to/libcustom.so')`` to load the
shared object.
After the library is loaded, it is added to the
``torch.ops.loaded_libraries`` attribute, a set that may be inspected
for the paths of all libraries loaded using this function.
Arguments:
path (str): A path to a shared library to load.
"""
with dl_open_guard():
# Import the shared library into the process, thus running its
# static (global) initialization code in order to register custom
# operators with the JIT.
ctypes.CDLL(path)
self.loaded_libraries.add(path)
# The ops "namespace" # The ops "namespace"
ops = _Ops() ops = _Ops()

View File

@ -1,5 +1,6 @@
#include "torch/csrc/jit/constants.h" #include "torch/csrc/jit/constants.h"
#include "torch/csrc/jit/operator.h" #include "torch/csrc/jit/operator.h"
#include "torch/csrc/jit/custom_operator.h"
#include "torch/csrc/autograd/variable.h" #include "torch/csrc/autograd/variable.h"
namespace torch { namespace jit { namespace torch { namespace jit {

View File

@ -1,7 +1,6 @@
#pragma once #pragma once
#include <torch/csrc/jit/function_schema.h> #include <torch/csrc/jit/function_schema.h>
#include <torch/csrc/jit/ir.h>
#include <torch/csrc/jit/operator.h> #include <torch/csrc/jit/operator.h>
#include <torch/csrc/jit/stack.h> #include <torch/csrc/jit/stack.h>
#include <torch/csrc/jit/tracer.h> #include <torch/csrc/jit/tracer.h>
@ -79,13 +78,10 @@ Node* getTracedNode(
/// Does two things for an operator implementation and a tuple of arguments: /// Does two things for an operator implementation and a tuple of arguments:
/// 1. Pops all necessary arguments off the stack into the tuple's elements, /// 1. Pops all necessary arguments off the stack into the tuple's elements,
/// 2. Unpacks the tuple and calls the operator implementation. /// 2. Unpacks the tuple and calls the operator implementation.
/// The result of the implementation call is returned. /// If tracing is currently enabled, this function will also take care of
template < /// tracing the operator call.
typename ReturnType, template <typename Implementation, typename... Types, size_t... Is>
typename Implementation, void callOperatorWithTuple(
typename... Types,
size_t... Is>
ReturnType callOperatorWithTuple(
const FunctionSchema& schema, const FunctionSchema& schema,
Implementation&& implementation, Implementation&& implementation,
Stack& stack, Stack& stack,
@ -104,10 +100,10 @@ ReturnType callOperatorWithTuple(
jit::tracer::postRecordTrace(node, result); jit::tracer::postRecordTrace(node, result);
} }
return result; push(stack, IValue(std::move(result)));
} }
void checkArgumentVector( inline void checkArgumentVector(
const char* what, const char* what,
const std::vector<Argument>& inferred, const std::vector<Argument>& inferred,
const std::vector<Argument>& provided, const std::vector<Argument>& provided,
@ -204,21 +200,54 @@ Operator createOperator(
c10::guts::typelist::map_t<decay_t, typename Traits::parameter_types>; c10::guts::typelist::map_t<decay_t, typename Traits::parameter_types>;
using ArgumentTuple = using ArgumentTuple =
typename c10::guts::typelist::to_tuple<ArgumentTypes>::type; typename c10::guts::typelist::to_tuple<ArgumentTypes>::type;
using ReturnType = decay_t<typename Traits::return_type>;
auto schema = torch::jit::detail::inferAndCheckSchema<Traits>(schemaOrName); auto schema = torch::jit::detail::inferAndCheckSchema<Traits>(schemaOrName);
return Operator(schema, [implementation, schema](Stack& stack) { return Operator(schema, [implementation, schema](Stack& stack) {
ArgumentTuple tuple; ArgumentTuple tuple;
auto result = torch::jit::detail::callOperatorWithTuple<ReturnType>( torch::jit::detail::callOperatorWithTuple(
schema, schema,
std::move(implementation), std::move(implementation),
stack, stack,
tuple, tuple,
typename MakeIndices<std::tuple_size<ArgumentTuple>::value>::indices{}); typename MakeIndices<std::tuple_size<ArgumentTuple>::value>::indices{});
pack(stack, std::move(result));
return 0; return 0;
}); });
} }
/// Registration class for new operators. Effectively calls
/// `torch::jit::registerOperator` for every supplied operator, but allows doing
/// so in the global scope when a `RegisterOperators` object is assigned to a
/// static variable. Also handles registration of user-defined, "custom"
/// operators.
struct TORCH_API RegisterOperators {
RegisterOperators() = default;
/// Registers a vector of already created `Operator`s.
RegisterOperators(std::vector<Operator> operators) {
for (Operator& o : operators) {
registerOperator(std::move(o));
}
}
/// Calls `op(...)` with the given operator name and implementation.
template <typename Implementation>
RegisterOperators(const std::string& name, Implementation&& implementation) {
op(name, std::forward<Implementation>(implementation));
}
/// Creates a new operator from a name and implementation function (function
/// pointer or function object/lambda) using `torch::jit::createOperator`, and
/// then registers the operator.
template <typename Implementation>
RegisterOperators& op(
const std::string& name,
Implementation&& implementation) {
registerOperator(
createOperator(name, std::forward<Implementation>(implementation)));
return *this;
}
};
} // namespace jit } // namespace jit
} // namespace torch } // namespace torch

View File

@ -33,6 +33,13 @@
#include <pybind11/functional.h> #include <pybind11/functional.h>
#include <memory>
#include <sstream>
#include <stdexcept>
#include <string>
#include <tuple>
#include <utility>
namespace torch { namespace jit { namespace torch { namespace jit {
namespace { namespace {
@ -232,10 +239,14 @@ void initJITBindings(PyObject *module) {
"Found ", operations.size(), " overloads for operator ", "Found ", operations.size(), " overloads for operator ",
qualified_name, "! Overloads are not supported from Python."); qualified_name, "! Overloads are not supported from Python.");
std::shared_ptr<Operator> op = operations[0]; std::shared_ptr<Operator> op = operations[0];
AT_ASSERT(op != nullptr);
std::ostringstream docstring;
docstring << "Automatically bound operator '" << qualified_name
<< "' with schema: " << op->schema();
return py::cpp_function([op](py::args args, py::kwargs kwargs) { return py::cpp_function([op](py::args args, py::kwargs kwargs) {
return invokeOperatorFromPython( return invokeOperatorFromPython(
*op, std::move(args), std::move(kwargs)); *op, std::move(args), std::move(kwargs));
}); }, py::name(qualified_name.c_str()), py::doc(docstring.str().c_str()));
} catch (const at::Error& error) { } catch (const at::Error& error) {
throw std::runtime_error(error.what_without_backtrace()); throw std::runtime_error(error.what_without_backtrace());
} }

View File

@ -94,14 +94,6 @@ void registerOperator(Operator&& op);
// XXX: this function is meant to be used with string literals only! // XXX: this function is meant to be used with string literals only!
Operator& sig(const char *signature_literal); Operator& sig(const char *signature_literal);
struct TORCH_API RegisterOperators {
RegisterOperators(std::vector<Operator> operators) {
for(Operator& o : operators) {
registerOperator(std::move(o));
}
}
};
struct OperatorSet { struct OperatorSet {
OperatorSet(std::initializer_list<const char *> sig_literals); OperatorSet(std::initializer_list<const char *> sig_literals);
// XXX: Returns a nullptr if no Operator in the set matches n // XXX: Returns a nullptr if no Operator in the set matches n

View File

@ -7,6 +7,7 @@
#include "torch/csrc/autograd/variable.h" #include "torch/csrc/autograd/variable.h"
#include "torch/csrc/jit/fusion_compiler.h" #include "torch/csrc/jit/fusion_compiler.h"
#include "torch/csrc/jit/operator.h" #include "torch/csrc/jit/operator.h"
#include "torch/csrc/jit/custom_operator.h"
#include "torch/csrc/jit/graph_executor.h" #include "torch/csrc/jit/graph_executor.h"
#include "torch/csrc/jit/ir.h" #include "torch/csrc/jit/ir.h"

View File

@ -7,6 +7,7 @@
#include "torch/csrc/jit/graph_executor.h" #include "torch/csrc/jit/graph_executor.h"
#include "torch/csrc/jit/ir.h" #include "torch/csrc/jit/ir.h"
#include "torch/csrc/jit/operator.h" #include "torch/csrc/jit/operator.h"
#include "torch/csrc/jit/custom_operator.h"
#include "torch/csrc/variable_tensor_functions.h" #include "torch/csrc/variable_tensor_functions.h"

View File

@ -1058,6 +1058,34 @@ void testCustomOperators() {
REQUIRE(output[0] == 1.0); REQUIRE(output[0] == 1.0);
REQUIRE(output[1] == 2.0); REQUIRE(output[1] == 2.0);
} }
{
RegisterOperators reg(
"foo::lists2(Tensor[] tensors) -> Tensor[]",
[](std::vector<at::Tensor> tensors) { return tensors; });
auto& ops =
getAllOperatorsFor(Symbol::fromQualString("foo::lists2"));
REQUIRE(ops.size() == 1);
auto& op = ops.front();
REQUIRE(op->schema().name == "foo::lists2");
REQUIRE(op->schema().arguments.size() == 1);
REQUIRE(op->schema().arguments[0].name == "tensors");
REQUIRE(op->schema().arguments[0].type->isSubtypeOf(ListType::ofTensors()));
REQUIRE(op->schema().returns.size() == 1);
REQUIRE(op->schema().returns[0].type->isSubtypeOf(ListType::ofTensors()));
Stack stack;
push(stack, std::vector<at::Tensor>{autograd::make_variable(at::ones(5))});
op->getOperation()(stack);
std::vector<at::Tensor> output;
pop(stack, output);
REQUIRE(output.size() == 1);
REQUIRE(output[0].allclose(autograd::make_variable(at::ones(5))));
}
{ {
#ifdef USE_CATCH #ifdef USE_CATCH
REQUIRE_THROWS_WITH( REQUIRE_THROWS_WITH(

11
torch/op.h Normal file
View File

@ -0,0 +1,11 @@
#pragma once
#include <torch/csrc/autograd/generated/variable_factories.h>
#include <torch/csrc/jit/custom_operator.h>
#include <ATen/ATen.h>
namespace torch {
using jit::createOperator;
using jit::RegisterOperators;
} // namespace torch