From f81db8afb84d0365055d07383bf63f71dbc55ea5 Mon Sep 17 00:00:00 2001 From: Horace He Date: Fri, 2 Aug 2019 18:41:34 -0700 Subject: [PATCH] Initial torchbind prototype (#21098) Summary: I have some test code in there as well, along with a script "test_libtorch" to run it. You'll need to modify `test_libtorch` to point to where you have `pytorch` built. I currently require that `pybind11` is included as a subdirectory of the test, but added it to the `.gitignore` to make this reviewable. Currently, something like this works: ```cpp struct Foo { int x, y; Foo(): x(2), y(5){} Foo(int x_, int y_) : x(x_), y(y_) {} void display() { cout<<"x: "<("Foo") .def(torch::jit::init()) .def("display", &Foo::display) .def("add", &Foo::add) .def("combine", &Foo::combine); ``` with ```py torch.jit.script def f(x): val = torch._C.Foo(5, 3) val.display() print(val.add(3)) ``` results in ``` x: 5 y: 3 24 ``` Current issues: - [x] The python class created by torchscript doesn't interactly properly with the surrounding code. ``` torch.jit.script def f(x): val = torch._C.Foo(5, 3) return val ``` - [x] Doesn't properly take in non-pointer classes. Can't define this function signature in cpp (We don't want to support this I believe). ```cpp void combine(Foo x) { ``` - [x] Has some issues with memory for blobs when constructing multiple objects (fix constant propagation pass to not treat capsules as the same object). ```py torch.jit.script def f(x): val = torch._C.Foo(5, 3) val2 = torch._C.Foo(100, 0) val.display() print(val.add(3)) ``` - [ ] Can't define multiple constructors (need to define overload string. Currently not possible since we don't support overloaded methods). - [x] `init` is a little bit different syntax than `pybind`. `.init<...>()` instead of `.def(py::init<>())` - [x] I couldn't figure out how to add some files into the build so they'd be copied to the `include/` directories, so I symlinked them manually. - [ ] Currently, the conversion from Python into Torchscript doesn't work. - [ ] Torchbind also currently requires Python/Pybind dependency. Fixing this would probably involve some kind of macro to bind into Python when possible. - [ ] We pass back into Python by value, currently. There's no way of passing by reference. - [x] Currently can only register one method with the same type signature. This is because we create a `static auto opRegistry`, and the function is templated on the type signature. Somewhat blocked on https://github.com/pytorch/pytorch/pull/21177. We currently use some structures that will be refactored by his PR (namely `return_type_to_ivalue` and `ivalue_to_arg_type`. Pull Request resolved: https://github.com/pytorch/pytorch/pull/21098 Differential Revision: D16634872 Pulled By: Chillee fbshipit-source-id: 1408bb89ea649c27d560df59e2cf9920467fe1de --- .jenkins/pytorch/macos-test.sh | 1 + .jenkins/pytorch/test.sh | 1 + .../test_custom_script_ops.bat | 2 + aten/src/ATen/core/ivalue.cpp | 13 ++ aten/src/ATen/core/ivalue.h | 34 ++- aten/src/ATen/core/ivalue_inl.h | 95 ++++++++ aten/src/ATen/core/jit_type.h | 37 +++- .../core/op_registration/kernel_functor.h | 13 +- aten/src/ATen/core/type.cpp | 4 + c10/util/C++17.h | 17 ++ caffe2/CMakeLists.txt | 2 +- test/custom_operator/CMakeLists.txt | 5 + test/custom_operator/classes.cpp | 65 ++++++ test/custom_operator/test_custom_classes.py | 80 +++++++ torch/CMakeLists.txt | 2 +- torch/__init__.py | 1 + torch/_classes.py | 9 + torch/_jit_internal.py | 4 + torch/_ops.py | 1 - torch/csrc/THP_export.h | 12 +- torch/csrc/jit/node_hashing.cpp | 2 + torch/csrc/jit/pybind_utils.h | 18 ++ torch/csrc/jit/script/schema_type_parser.cpp | 2 + torch/custom_class.h | 207 ++++++++++++++++++ 24 files changed, 607 insertions(+), 20 deletions(-) create mode 100644 test/custom_operator/classes.cpp create mode 100644 test/custom_operator/test_custom_classes.py create mode 100644 torch/_classes.py create mode 100644 torch/custom_class.h diff --git a/.jenkins/pytorch/macos-test.sh b/.jenkins/pytorch/macos-test.sh index 0cab621a565..30482f5b607 100755 --- a/.jenkins/pytorch/macos-test.sh +++ b/.jenkins/pytorch/macos-test.sh @@ -116,6 +116,7 @@ test_custom_script_ops() { # Run tests Python-side and export a script module. python test_custom_ops.py -v + python test_custom_classes.py -v python model.py --export-script-module=model.pt # Run tests C++-side and load the exported script module. build/test_custom_ops ./model.pt diff --git a/.jenkins/pytorch/test.sh b/.jenkins/pytorch/test.sh index 667bd49ea93..a512da2c370 100755 --- a/.jenkins/pytorch/test.sh +++ b/.jenkins/pytorch/test.sh @@ -162,6 +162,7 @@ test_custom_script_ops() { cp -a "$CUSTOM_OP_BUILD" build # Run tests Python-side and export a script module. python test_custom_ops.py -v + python test_custom_classes.py -v python model.py --export-script-module=model.pt # Run tests C++-side and load the exported script module. build/test_custom_ops ./model.pt diff --git a/.jenkins/pytorch/win-test-helpers/test_custom_script_ops.bat b/.jenkins/pytorch/win-test-helpers/test_custom_script_ops.bat index d86692dbabb..90c0e5de75a 100644 --- a/.jenkins/pytorch/win-test-helpers/test_custom_script_ops.bat +++ b/.jenkins/pytorch/win-test-helpers/test_custom_script_ops.bat @@ -1,5 +1,6 @@ call %SCRIPT_HELPERS_DIR%\setup_pytorch_env.bat +git submodule update --init --recursive third_party/pybind11 cd test\custom_operator :: Build the custom operator library. @@ -23,6 +24,7 @@ popd :: Run tests Python-side and export a script module. python test_custom_ops.py -v +python test_custom_classes.py -v python model.py --export-script-module="build/model.pt" :: Run tests C++-side and load the exported script module. cd build diff --git a/aten/src/ATen/core/ivalue.cpp b/aten/src/ATen/core/ivalue.cpp index 83d5b333d04..788d3534fc2 100644 --- a/aten/src/ATen/core/ivalue.cpp +++ b/aten/src/ATen/core/ivalue.cpp @@ -102,6 +102,8 @@ std::ostream& operator<<(std::ostream & out, const IValue & v) { return printList(out, v.toTensorList(), "[", "]"); case IValue::Tag::Blob: return out << *v.toBlob(); + case IValue::Tag::Capsule: + return out << "Capsule"; case IValue::Tag::GenericList: return printList(out, v.toGenericList(), "[", "]"); case IValue::Tag::Future: @@ -170,4 +172,15 @@ std::vector> iterationOrder(const c10::Dict& getCustomClassTypeMap() { + static std::unordered_map tmap; + return tmap; +} + +std::unordered_map>& +getClassConverter() { + static std::unordered_map> + classConverter; + return classConverter; +} } // namespace c10 diff --git a/aten/src/ATen/core/ivalue.h b/aten/src/ATen/core/ivalue.h index 5ad7def0973..d7ba874a3ef 100644 --- a/aten/src/ATen/core/ivalue.h +++ b/aten/src/ATen/core/ivalue.h @@ -3,9 +3,11 @@ #include #include #include +#include namespace torch { namespace jit { +class CustomClassHolder : public c10::intrusive_ptr_target {}; struct Function; namespace script { struct CompilationUnit; @@ -49,8 +51,10 @@ struct Object; _(GenericDict) \ _(Future) \ _(Device) \ + _(Object) \ _(Uninitialized) \ - _(Object) + _(Capsule) \ + struct CAFFE2_API IValue final { IValue() : payload{0}, tag(Tag::None), is_intrusive_ptr(false) {} @@ -148,6 +152,14 @@ struct CAFFE2_API IValue final { c10::intrusive_ptr toBlob() &&; c10::intrusive_ptr toBlob() const &; + // Capsule + IValue(intrusive_ptr blob); + bool isCapsule() const { + return Tag::Capsule == tag; + } + c10::intrusive_ptr toCapsule() &&; + c10::intrusive_ptr toCapsule() const &; + // Tuple IValue(c10::intrusive_ptr v); bool isTuple() const { return Tag::Tuple == tag; } @@ -564,6 +576,26 @@ struct StrongTypePtr { std::shared_ptr cu_; std::shared_ptr type_; }; + +TORCH_API std::unordered_map& getCustomClassTypeMap(); +template +c10::StrongTypePtr getCustomClassType() { + auto tmap = c10::getCustomClassTypeMap(); + auto res = tmap.find(typeid(T).name()); + if (res == tmap.end()) { + throw c10::Error("Can't find class id in custom class type map", ""); + } + return res->second; +} + +template +inline bool isCustomClassRegistered() { + auto tmap = c10::getCustomClassTypeMap(); + return tmap.find(typeid(T).name()) != tmap.end(); +} + +TORCH_API std::unordered_map>& +getClassConverter(); } #include diff --git a/aten/src/ATen/core/ivalue_inl.h b/aten/src/ATen/core/ivalue_inl.h index f2a5dd5f960..c476f0d36bc 100644 --- a/aten/src/ATen/core/ivalue_inl.h +++ b/aten/src/ATen/core/ivalue_inl.h @@ -24,6 +24,21 @@ struct IValue; struct ClassType; struct TupleType; +// For custom class __init__ registration, we need to pass in a function +// that looks like this: [](IValue x, args...) + +// However, kernel_functor.h automatically sets the input types of the function +// by introspecting the types of the functor (which is IValue in this case). +// However, we need the type it binds to be Foo. + +// Instead, we pass in a lambda [](ivalue_holder x, args...) from +// which getTypePtr can recover the original class pointer. + +template +struct tagged_capsule { + IValue ivalue; +}; + template c10::intrusive_ptr IValue::moveToIntrusivePtr() { auto t = c10::intrusive_ptr::reclaim(static_cast(payload.as_intrusive_ptr)); @@ -38,6 +53,11 @@ c10::intrusive_ptr IValue::toIntrusivePtr() const { return p; } +template +intrusive_ptr static_intrusive_pointer_cast(intrusive_ptr r) { + return intrusive_ptr::reclaim(static_cast(r.release())); +} + inline c10::intrusive_ptr IValue::toFuture() && { AT_ASSERT(isFuture(), "Expected Future but got ", tagKind()); return moveToIntrusivePtr(); @@ -78,6 +98,14 @@ inline c10::intrusive_ptr IValue::toBlob() const & { AT_ASSERT(isBlob(), "Expected Blob but got ", tagKind()); return toIntrusivePtr();; } +inline c10::intrusive_ptr IValue::toCapsule() && { + TORCH_INTERNAL_ASSERT(isCapsule()); + return moveToIntrusivePtr(); +} +inline c10::intrusive_ptr IValue::toCapsule() const & { + TORCH_INTERNAL_ASSERT(isCapsule()); + return toIntrusivePtr(); +} namespace ivalue { @@ -430,6 +458,23 @@ std::vector generic_to( return result; } +template +T generic_to( + IValue ivalue, + _fake_type) { + using ElemType = typename std::remove_pointer::type::element_type; + auto obj = ivalue.toObject(); + auto capsule = obj->getSlot(0); + return c10::static_intrusive_pointer_cast(capsule.toCapsule()); +} + +template +tagged_capsule generic_to( + IValue ivalue, + _fake_type>) { + return tagged_capsule{ivalue}; +} + template c10::List generic_to( IValue ivalue, @@ -640,6 +685,10 @@ inline IValue::IValue(c10::intrusive_ptr v) : tag(Tag::Object), is_intrusive_ptr(true) { payload.as_intrusive_ptr = v.release(); } +inline IValue::IValue(c10::intrusive_ptr v) +: tag(Tag::Capsule), is_intrusive_ptr(true) { + payload.as_intrusive_ptr = v.release(); +} inline IValue::IValue(c10::intrusive_ptr v) : tag(Tag::Future), is_intrusive_ptr(true) { payload.as_intrusive_ptr = v.release(); @@ -687,4 +736,50 @@ inline bool IValue::isSameIdentity(const IValue& rhs) const { } } +namespace ivalue { +namespace detail { +// This code allows us to template on a function based on whether IValue has a +// constructor for it. Specifically, has_constructor{} inherits from std::true_type if +// IValue(T) compiles, and inherits from std::false_type if IValue(T) doesn't. +// We use it for calling the IValue constructor for `from` if it exists, and otherwise +// attempt to use our custom class code. +template struct type_sink { typedef void type; }; +template using type_sink_t = typename type_sink::type; +template struct has_constructor : std::false_type {}; \ +template struct has_constructor< + T, + type_sink_t< decltype( IValue(std::declval())) > +>: std::true_type {}; + +template +IValue from_(T x, std::true_type) { + return IValue(x); +} +template +IValue from_(c10::intrusive_ptr x, std::false_type) { + using inputType = c10::intrusive_ptr; + if (!isCustomClassRegistered()) { + throw c10::Error("Trying to return a class that we don't support and isn't a registered custom class.", ""); + } + auto res = getCustomClassType(); + auto retObject = ivalue::Object::create(res->second, 1); + auto objPtr = c10::static_intrusive_pointer_cast(x); + + retObject->setSlot(0, IValue(objPtr)); + auto resIVal = IValue(std::move(retObject)); + return resIVal; +} +template +IValue from_(T x, std::false_type) { + static_assert(guts::false_t::value, "You are calling from with a type that it doesn't support, and isn't a potential custom class (ie: is an intrusive_ptr)"); + return IValue(); +} +} + +template +IValue from(T x) { + return detail::from_(x, detail::has_constructor{}); +} + +} } // namespace c10 diff --git a/aten/src/ATen/core/jit_type.h b/aten/src/ATen/core/jit_type.h index 3b8044adfa7..97bc44ce3ac 100644 --- a/aten/src/ATen/core/jit_type.h +++ b/aten/src/ATen/core/jit_type.h @@ -13,6 +13,7 @@ #include #include +struct ClassType; namespace torch { namespace jit { struct Function; @@ -48,7 +49,8 @@ using OptNameList = c10::optional>; _(ProfiledTensorType) \ _(DeviceObjType) \ _(FunctionType) \ - _(ClassType) + _(ClassType) \ + _(CapsuleType) enum class TypeKind { #define DEFINE_TYPE(T) T, @@ -1304,6 +1306,28 @@ struct VarType : public Type { std::string name_; }; +struct CapsuleType; +using CapsuleTypePtr = std::shared_ptr; +// This type represents a Python Capsule +struct CAFFE2_API CapsuleType : public Type { + static CapsuleTypePtr create() { + return CapsuleTypePtr(new CapsuleType()); // NOLINT(modernize-make-shared) + } + DEFINE_IS_SUBCLASS(CapsuleType); + bool operator==(const Type& rhs) const override { + return rhs.kind() == kind(); + } + std::string str() const override { + return "Capsule"; + } + static const TypeKind Kind = TypeKind::CapsuleType; + // global singleton + static CapsuleTypePtr get(); +private: + CapsuleType() + : Type(TypeKind::CapsuleType) {} +}; + CAFFE2_API std::ostream& operator<<(std::ostream& out, const Type& t); CAFFE2_API std::ostream& operator<<(std::ostream& out, const VaryingShape& t); // what is the type, ignoring extra size/shape information? @@ -1359,9 +1383,13 @@ CAFFE2_API c10::optional unifyTypes( namespace detail { template struct getTypePtr_ final { - static_assert( - guts::false_t::value, - "Type could not be converted to any of the known types."); + static TypePtr call() { + if (!isCustomClassRegistered()) { + throw c10::Error("Type could not be converted to any of the known types.", ""); + } + auto res = getCustomClassType(); + return std::dynamic_pointer_cast(res.type_); + } }; template <> @@ -1633,4 +1661,5 @@ struct CAFFE2_API ClassType : public NamedType { // List of methods associated with this class. std::vector methods_; }; + } // namespace c10 diff --git a/aten/src/ATen/core/op_registration/kernel_functor.h b/aten/src/ATen/core/op_registration/kernel_functor.h index fa09b61297f..77e4a410c72 100644 --- a/aten/src/ATen/core/op_registration/kernel_functor.h +++ b/aten/src/ATen/core/op_registration/kernel_functor.h @@ -1,6 +1,7 @@ #pragma once #include +#include namespace c10 { /** @@ -37,7 +38,10 @@ namespace detail { >; template struct assert_is_valid_input_type { - static_assert(guts::false_t::value, "You tried to register a kernel with an unsupported input type."); + assert_is_valid_input_type() { + auto tmap = c10::getCustomClassTypeMap(); + TORCH_CHECK(c10::isCustomClassRegistered(), "Tried to use undefined class as input argument"); + } }; template @@ -98,7 +102,10 @@ namespace detail { }; template struct assert_is_valid_output_type { - static_assert(guts::false_t::value, "You tried to register a kernel with an unsupported output type."); + assert_is_valid_output_type() { + auto tmap = getCustomClassTypeMap(); + TORCH_CHECK(c10::isCustomClassRegistered(), "Tried to use undefined class as output"); + } }; template @@ -170,7 +177,7 @@ namespace detail { template IValue return_to_ivalue(T&& v) { assert_is_valid_output_type(); - return IValue(std::move(v)); + return c10::ivalue::from(v); } template diff --git a/aten/src/ATen/core/type.cpp b/aten/src/ATen/core/type.cpp index 11d79a5b490..197fd17e64c 100644 --- a/aten/src/ATen/core/type.cpp +++ b/aten/src/ATen/core/type.cpp @@ -119,6 +119,10 @@ OptionalTypePtr OptionalType::ofTensor() { static auto value = OptionalType::create(TensorType::get()); return value; } +CapsuleTypePtr CapsuleType::get() { + static auto value = CapsuleType::create(); + return value; +} ListTypePtr ListType::ofTensors() { static auto value = ListType::create(TensorType::get()); return value; diff --git a/c10/util/C++17.h b/c10/util/C++17.h index 8249e73844d..41f8fe00f30 100644 --- a/c10/util/C++17.h +++ b/c10/util/C++17.h @@ -8,6 +8,7 @@ #include #include #include +#include #include /* @@ -229,6 +230,21 @@ constexpr auto apply(F&& f, Tuple&& t) -> decltype(detail::apply_impl( #endif #endif +template +typename std::enable_if< + std::is_member_pointer::type>::value, + typename std::result_of::type>::type +invoke(Functor&& f, Args&&... args) { + return std::mem_fn(f)(std::forward(args)...); +} + +template +typename std::enable_if< + !std::is_member_pointer::type>::value, + typename std::result_of::type>::type +invoke(Functor&& f, Args&&... args) { + return std::forward(f)(std::forward(args)...); +} @@ -243,6 +259,7 @@ namespace std { // std::to_string() call, then you're calling std::to_string() but should be calling // c10::guts::to_string(). inline std::string to_string(c10::guts::detail::DummyClassForToString) { return ""; } + } namespace c10 { namespace guts { namespace detail { diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt index 3e0d93d54f3..e221b4fb806 100644 --- a/caffe2/CMakeLists.txt +++ b/caffe2/CMakeLists.txt @@ -716,7 +716,7 @@ ENDIF() install(DIRECTORY "${TORCH_SRC_DIR}/csrc" DESTINATION ${TORCH_INSTALL_INCLUDE_DIR}/torch FILES_MATCHING PATTERN "*.h") - install(FILES "${TORCH_SRC_DIR}/script.h" "${TORCH_SRC_DIR}/extension.h" + install(FILES "${TORCH_SRC_DIR}/script.h" "${TORCH_SRC_DIR}/extension.h" "${TORCH_SRC_DIR}/custom_class.h" DESTINATION ${TORCH_INSTALL_INCLUDE_DIR}/torch) diff --git a/test/custom_operator/CMakeLists.txt b/test/custom_operator/CMakeLists.txt index 65a4ac35a72..38fd42c4910 100644 --- a/test/custom_operator/CMakeLists.txt +++ b/test/custom_operator/CMakeLists.txt @@ -5,6 +5,11 @@ project(custom_ops) find_package(Torch REQUIRED) add_library(custom_ops SHARED op.cpp) + +add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/../../third_party/pybind11/ ./pybind11) +pybind11_add_module(custom_class SHARED classes.cpp) +target_link_libraries(custom_class PRIVATE "${TORCH_LIBRARIES}") + target_compile_features(custom_ops PUBLIC cxx_range_for) target_link_libraries(custom_ops "${TORCH_LIBRARIES}") target_compile_definitions(custom_ops PRIVATE custom_ops_EXPORTS) diff --git a/test/custom_operator/classes.cpp b/test/custom_operator/classes.cpp new file mode 100644 index 00000000000..6e9a44baae1 --- /dev/null +++ b/test/custom_operator/classes.cpp @@ -0,0 +1,65 @@ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace std; + +namespace py = pybind11; + +struct Foo : torch::jit::CustomClassHolder { + int x, y; + Foo(): x(0), y(0){} + Foo(int x_, int y_) : x(x_), y(y_) {} + int64_t info() { + return this->x * this->y; + } + int64_t add(int64_t z) { + return (x+y)*z; + } + void increment(int64_t z) { + this->x+=z; + this->y+=z; + } + int64_t combine(c10::intrusive_ptr b) { + return this->info() + b->info(); + } + ~Foo() { + // std::cout<<"Destroying object with values: "< struct Stack : torch::jit::CustomClassHolder { + std::vector stack_; + Stack(std::vector init): stack_(init.begin(), init.end()) {} + + void push(T x) { + stack_.push_back(x); + } + T pop() { + auto val = stack_.back(); + stack_.pop_back(); + return val; + } +}; + +static auto test = torch::jit::class_("Foo") + .def(torch::jit::init()) + // .def(torch::jit::init<>()) + .def("info", &Foo::info) + .def("increment", &Foo::increment) + // .def("add", &Foo::add); + .def("combine", &Foo::combine) + ; + +static auto testStack = torch::jit::class_>("StackString") + .def(torch::jit::init>()) + .def("push", &Stack::push) + .def("pop", &Stack::pop) + ; diff --git a/test/custom_operator/test_custom_classes.py b/test/custom_operator/test_custom_classes.py new file mode 100644 index 00000000000..2d84bcbf525 --- /dev/null +++ b/test/custom_operator/test_custom_classes.py @@ -0,0 +1,80 @@ +import unittest +import torch +from torch import ops +import torch.jit as jit +import glob +import os + +def get_custom_class_library_path(): + library_filename = glob.glob("build/*custom_class*") + assert (len(library_filename) == 1) + library_filename = library_filename[0] + path = os.path.abspath(library_filename) + assert os.path.exists(path), path + return path + +def test_equality(f, cmp_key): + obj1 = f() + obj2 = jit.script(f)() + return (cmp_key(obj1), cmp_key(obj2)) + +class TestCustomOperators(unittest.TestCase): + def setUp(self): + ops.load_library(get_custom_class_library_path()) + + def test_no_return_class(self): + def f(): + val = torch.classes.Foo(5, 3) + return val.info() + self.assertEqual(*test_equality(f, lambda x: x)) + + def test_constructor_with_args(self): + def f(): + val = torch.classes.Foo(5, 3) + return val + self.assertEqual(*test_equality(f, lambda x: x.info())) + + def test_function_call_with_args(self): + def f(): + val = torch.classes.Foo(5, 3) + val.increment(1) + return val + + self.assertEqual(*test_equality(f, lambda x: x.info())) + + def test_function_method_wrong_type(self): + def f(): + val = torch.classes.Foo(5, 3) + val.increment("asdf") + return val + + with self.assertRaisesRegex(RuntimeError, "Expected"): + jit.script(f)() + + @unittest.skip("We currently don't support passing custom classes to custom methods.") + def test_input_class_type(self): + def f(): + val = torch.classes.Foo(1, 2) + val2 = torch.classes.Foo(2, 3) + val.combine(val2) + return val + + self.assertEqual(*test_equality(f, lambda x: x.info())) + + def test_stack_string(self): + def f(): + val = torch.classes.StackString(["asdf", "bruh"]) + return val.pop() + self.assertEqual(*test_equality(f, lambda x: x)) + + def test_stack_push_pop(self): + def f(): + val = torch.classes.StackString(["asdf", "bruh"]) + val2 = torch.classes.StackString(["111", "222"]) + val.push(val2.pop()) + return val.pop() + val2.pop() + self.assertEqual(*test_equality(f, lambda x: x)) + + +if __name__ == "__main__": + unittest.main() diff --git a/torch/CMakeLists.txt b/torch/CMakeLists.txt index f8f9c8c775e..bbb6f8f8b8f 100644 --- a/torch/CMakeLists.txt +++ b/torch/CMakeLists.txt @@ -244,7 +244,7 @@ if (USE_NCCL) endif() # In the most recent CMake versions, a new 'TRANSFORM' subcommand of 'list' allows much of the boilerplate of defining the lists -# of type stub files to be omitted. +# of type stub files to be omitted. # For comptability with older CMake versions, we omit it for now, but leave it as a comment in case comptability with the older # CMake versions is eventually dropped. # set(Modules diff --git a/torch/__init__.py b/torch/__init__.py index 7206177114e..b34e3c6a830 100644 --- a/torch/__init__.py +++ b/torch/__init__.py @@ -336,6 +336,7 @@ def compiled_with_cxx11_abi(): # Import the ops "namespace" from torch._ops import ops # noqa: F401 +from torch._classes import classes # noqa: F401 # Import the quasi random sampler import torch.quasirandom diff --git a/torch/_classes.py b/torch/_classes.py new file mode 100644 index 00000000000..ff7ae6c2821 --- /dev/null +++ b/torch/_classes.py @@ -0,0 +1,9 @@ +import types + +class _Classes(types.ModuleType): + def __init__(self): + super(_Classes, self).__init__('torch.classes') + + +# The classes "namespace" +classes = _Classes() diff --git a/torch/_jit_internal.py b/torch/_jit_internal.py index 8caecbba20f..18fd6513b1f 100644 --- a/torch/_jit_internal.py +++ b/torch/_jit_internal.py @@ -397,6 +397,10 @@ def _qualified_name(obj): name = obj.__name__ module_name = obj.__module__ + # If the module is actually a torchbind module, then we should short circuit + if module_name == "torch._classes": + return obj.qualified_name + # The Python docs are very clear that `__module__` can be None, but I can't # figure out when it actually would be. if module_name is None: diff --git a/torch/_ops.py b/torch/_ops.py index 66977839655..01c64be5ced 100644 --- a/torch/_ops.py +++ b/torch/_ops.py @@ -106,6 +106,5 @@ class _Ops(types.ModuleType): ctypes.CDLL(path) self.loaded_libraries.add(path) - # The ops "namespace" ops = _Ops() diff --git a/torch/csrc/THP_export.h b/torch/csrc/THP_export.h index 02aa322b64e..b7d77c624b0 100644 --- a/torch/csrc/THP_export.h +++ b/torch/csrc/THP_export.h @@ -1,22 +1,16 @@ #ifndef THP_EXPORT_H #define THP_EXPORT_H -#ifdef __cplusplus -# define THP_EXTERNC extern "C" -#else -# define THP_EXTERNC extern -#endif - #ifdef _WIN32 # ifdef _THP_CORE -# define THP_API THP_EXTERNC __declspec(dllexport) +# define THP_API extern __declspec(dllexport) # define THP_CLASS __declspec(dllexport) # else -# define THP_API THP_EXTERNC __declspec(dllimport) +# define THP_API extern __declspec(dllimport) # define THP_CLASS __declspec(dllimport) # endif #else -# define THP_API THP_EXTERNC +# define THP_API extern # define THP_CLASS #endif diff --git a/torch/csrc/jit/node_hashing.cpp b/torch/csrc/jit/node_hashing.cpp index 05cf2e9f2ac..4e33a9c0f9c 100644 --- a/torch/csrc/jit/node_hashing.cpp +++ b/torch/csrc/jit/node_hashing.cpp @@ -112,6 +112,8 @@ bool EqualNode::operator()(const Node* lhs, const Node* rhs) const { for (size_t i = 0; i < lhs_outputs.size(); ++i) { if (*lhs_outputs[i]->type() != *rhs_outputs[i]->type()) return false; + if (lhs_outputs[i]->type() == CapsuleType::get()) + return false; } // Check whether the inputs are the same. diff --git a/torch/csrc/jit/pybind_utils.h b/torch/csrc/jit/pybind_utils.h index ce0e345a897..3be05afa28b 100644 --- a/torch/csrc/jit/pybind_utils.h +++ b/torch/csrc/jit/pybind_utils.h @@ -6,6 +6,7 @@ #include #include #include +#include #include #include #include @@ -448,6 +449,8 @@ inline IValue toIValue( break; case TypeKind::FunctionType: AT_ERROR("Function Values aren't yet supported"); + case TypeKind::CapsuleType: + AT_ERROR("Capsule Values aren't supported"); } AT_ERROR( "Missing cases in toIValue for type: ", @@ -510,6 +513,17 @@ inline IValue returnToIValue(const TypePtr& type, py::handle object) { } } +inline c10::optional tryToConvertToCustomClass( + const c10::intrusive_ptr& obj) { + if (obj->name().find("__torch__.torch.classes") == 0) { + auto objPtr = (void*)obj->getSlot(0).toCapsule().release(); + auto classConverter = c10::getClassConverter()[obj->name()]; + py::handle rawPyObj = classConverter(objPtr); + auto o = py::reinterpret_steal(rawPyObj); + return o; + } + return c10::nullopt; +} inline py::object toPyObject(IValue&& ivalue) { if (ivalue.isNone()) { return py::none(); @@ -573,6 +587,10 @@ inline py::object toPyObject(IValue&& ivalue) { } else if (ivalue.isObject()) { const auto obj = std::move(ivalue).toObject(); auto pyCu = get_python_cu(); + auto res = tryToConvertToCustomClass(obj); + if (res.has_value()) { + return res.value(); + } const auto classType = pyCu->get_class(c10::QualifiedName(obj->name())); AT_ASSERT(classType); auto pyClass = diff --git a/torch/csrc/jit/script/schema_type_parser.cpp b/torch/csrc/jit/script/schema_type_parser.cpp index c63b32eed75..898b57f697b 100644 --- a/torch/csrc/jit/script/schema_type_parser.cpp +++ b/torch/csrc/jit/script/schema_type_parser.cpp @@ -19,6 +19,7 @@ using c10::GeneratorType; using c10::IntType; using c10::ListType; using c10::NoneType; +using c10::CapsuleType; using c10::NumberType; using c10::OptionalType; using c10::StringType; @@ -45,6 +46,7 @@ TypeAndAlias SchemaTypeParser::parseBaseType() { {"int", IntType::get()}, {"bool", BoolType::get()}, {"None", NoneType::get()}, + {"Capsule", CapsuleType::get()}, }; auto tok = L.cur(); if (!L.nextIf(TK_NONE)) { diff --git a/torch/custom_class.h b/torch/custom_class.h new file mode 100644 index 00000000000..77444ddd55e --- /dev/null +++ b/torch/custom_class.h @@ -0,0 +1,207 @@ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + +namespace py = pybind11; +namespace torch { +namespace jit { + +static std::vector registeredOps; + +namespace detail { +template +struct types { + constexpr static bool hasRet = true; + using type = types; +}; +template +struct types { + constexpr static bool hasRet = false; + using type = types; +}; +template +struct args; +template +struct args : types {}; +template +using args_t = typename args::type; +} // namespace detail +template +detail::types init() { return detail::types{}; } + +// To bind custom classes into Torchscript, use an API very similar to Pybind's. +// Currently exposes one class `torch::jit::class_` and 2 methods. +// - Constructing `torch::jit::class_` registers `Foo` in Python and +// Torchscript, and puts it under `torch.classes.Foo` in Python. +// - torch::jit::class_.def("method1", &Foo::method1) does some template +// metaprogramming to introspect the function types and register the operator +// for use in Torchscript. +// - torch::jit::class_.def(torch::jit::init()) registers +// the Foo(int, int) constructor. +// see test/custom_operator/classes.cpp and +// test/custom_operator/test_custom_classes.py for example usages + +template +class class_ { + std::string className; + std::string qualClassName; + c10::optional> pyClass = c10::nullopt; + std::shared_ptr classCu = nullptr; + ClassTypePtr classTypePtr; + + const std::string parentModule = "classes"; + const std::string topModule = "__torch__.torch"; + + public: + class_(string className_) : className(std::move(className_)) { + // Currently we register everything as a python class just for convenience. + // We'll want to remove this at some point to get rid of the python + // dependency. It would require significant changes to class registration, + // (I think)? + qualClassName = topModule + "." + parentModule + "." + className; + + auto obj = py::module::import("torch").attr(parentModule.c_str()); + pyClass = py::class_(obj, className.c_str()); + pyClass->attr("qualified_name") = py::str(qualClassName); + auto newClass = + py::module::import("torch.jit") + .attr("_add_script_class")(*pyClass, qualClassName.c_str()); + + auto castToPython = [](void* objPtr) -> PyObject* { + CurClass x = *static_cast(objPtr); + auto py_object = py::cast(x); + PyObject* rawPyObj = py_object.release().ptr(); + return rawPyObj; + }; + getClassConverter()[qualClassName] = castToPython; + + // We currently represent custom classes as torchscript classes with a + // capsule attribute + classCu = torch::jit::get_python_cu(); + classTypePtr = + ClassType::create(c10::QualifiedName(qualClassName), classCu); + classTypePtr->addAttribute("capsule", CapsuleType::get()); + + c10::getCustomClassTypeMap().insert({typeid(c10::intrusive_ptr).name(), + StrongTypePtr(classCu, classTypePtr)}); + c10::getCustomClassTypeMap().insert({typeid(c10::tagged_capsule).name(), + StrongTypePtr(classCu, classTypePtr)}); + + classCu->register_class(classTypePtr); + } + + template + class_& def(detail::types) { // Used in combination with + // torch::jit::init<...>() + pyClass->def(py::init()); + + auto func = [](c10::tagged_capsule self, Types... args) { + auto classObj = c10::make_intrusive(args...); + auto genericPtr = c10::static_intrusive_pointer_cast(classObj); + auto capsule = IValue(genericPtr); + auto object = self.ivalue.toObject(); + object->setSlot(0, capsule); + }; + + defineMethod("__init__", std::move(func), false); + return *this; + } + template + class_& def(string name, Func f) { + auto res = def_(name, f, detail::args_t{}); + return *this; + } + + private: + template + struct addInput { + static Value* call(std::shared_ptr graph) { + return graph->addInput()->setType(getTypePtr()); + } + }; + template + std::vector addInputs_( + Func f, + std::shared_ptr graph, + guts::index_sequence) { + using argTypes = + typename guts::infer_function_traits_t::parameter_types; + std::vector res = { + addInput>::call( + graph)...}; + return res; + } + template + std::vector addInputs(Func f, std::shared_ptr graph) { + constexpr auto numArgs = + guts::infer_function_traits_t::number_of_parameters; + return addInputs_(f, graph, guts::make_index_sequence()); + } + + template + std::string type_name() { + return std::string(typeid(Last).name()); + } + template + std::string type_name() { + return type_name() + "_" + type_name(); + } + + template + void addType(Value* v) { + v->setType(getTypePtr()); + } + template + void defineMethod(std::string name, Func func, bool hasRet) { + auto graph = std::make_shared(); + auto qualFuncName = className + "::" + name; + registeredOps.push_back( + torch::RegisterOperators().op(qualFuncName, std::move(func))); + + + std::vector inputs = addInputs(func, graph); + auto methodCall = graph->insertNode(graph->create( + Symbol::fromQualString(qualFuncName), inputs, hasRet)); + Value* res; + if (hasRet) { + res = methodCall->output(); + addType(res); + } else { + res = graph->insertConstant(IValue())->setType(NoneType::get()); + } + graph->registerOutput(res); + + classCu->create_function(qualClassName + "." + name, graph); + } + template + class_& def_(string name, Func f, detail::types funcInfo) { + pyClass->def(name.c_str(), f); + + auto func = [f](c10::intrusive_ptr cur, Types... args) { + return guts::invoke(f, *cur, args...); + }; + defineMethod(name, std::move(func), funcInfo.hasRet); + return *this; + } +}; + +} // namespace jit + +} // namespace torch