pytorch/test/custom_operator/test_custom_classes.py
Horace He f81db8afb8 Initial torchbind prototype (#21098)
Summary:
I have some test code in there as well, along with a script "test_libtorch" to run it. You'll need to modify `test_libtorch` to point to where you have `pytorch` built. I currently require that `pybind11` is included as a subdirectory of the test, but added it to the `.gitignore` to make this reviewable.

Currently, something like this works:
```cpp
struct Foo {
  int x, y;
  Foo(): x(2), y(5){}
  Foo(int x_, int y_) : x(x_), y(y_) {}
  void display() {
    std::cout<<"x: "<<x<<' '<<"y: "<<y<<std::endl;
  }
  int64_t add(int64_t z) {
    return (x+y)*z;
  }
};
static auto test = torch::jit::class_<Foo>("Foo")
                    .def(torch::jit::init<int64_t, int64_t>())
                    .def("display", &Foo::display)
                    .def("add", &Foo::add)
                    .def("combine", &Foo::combine);

```
with
```py
@torch.jit.script
def f(x):
    val = torch._C.Foo(5, 3)
    val.display()
    print(val.add(3))
```
results in
```
x: 5 y: 3
24
```
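
As a sanity check on the output above, here is a minimal pure-Python stand-in for `Foo`. It does not use torchbind or TorchScript at all; it only mirrors the arithmetic of the C++ struct. The fact that `display` also returns the printed line is an addition for easy checking and is not part of the original struct.

```python
class Foo:
    """Pure-Python stand-in for the C++ Foo struct (illustration only)."""

    def __init__(self, x=2, y=5):
        # Same defaults as the C++ default constructor.
        self.x = x
        self.y = y

    def display(self):
        # The C++ version returns void; returning the line here is an
        # assumption made purely so the result can be inspected.
        line = f"x: {self.x} y: {self.y}"
        print(line)
        return line

    def add(self, z):
        # Mirrors `return (x+y)*z;` from the C++ struct.
        return (self.x + self.y) * z


val = Foo(5, 3)
shown = val.display()
result = val.add(3)
print(result)
```

With `x = 5`, `y = 3`, and `z = 3`, `add` computes `(5 + 3) * 3`, which matches the `24` shown above.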

Current issues:
- [x] The Python class created by TorchScript doesn't interact properly with the surrounding code.
```
@torch.jit.script
def f(x):
    val = torch._C.Foo(5, 3)
    return val
```
- [x] Doesn't properly accept classes passed by value (non-pointer). A method with this signature can't be defined in C++ (I believe we don't want to support this anyway).
```cpp
  void combine(Foo x) {
```

- [x] Has some issues with memory for blobs when constructing multiple objects (fix constant propagation pass to not treat capsules as the same object).
```py
@torch.jit.script
def f(x):
    val = torch._C.Foo(5, 3)
    val2 = torch._C.Foo(100, 0)
    val.display()
    print(val.add(3))
```
- [ ] Can't define multiple constructors (need to define overload string. Currently not possible since we don't support overloaded methods).
- [x] `init` uses slightly different syntax than `pybind`: `.init<...>()` instead of `.def(py::init<>())`.
- [x] I couldn't figure out how to add some files into the build so they'd be copied to the `include/` directories, so I symlinked them manually.
- [ ] Currently, the conversion from Python into TorchScript doesn't work.
- [ ] Torchbind also currently requires a Python/pybind dependency. Fixing this would probably involve some kind of macro to bind into Python when possible.
- [ ] We pass back into Python by value, currently. There's no way of passing by reference.
- [x] Currently can only register one method with the same type signature. This is because we create a `static auto opRegistry`, and the function is templated on the type signature.

Somewhat blocked on https://github.com/pytorch/pytorch/pull/21177. We currently use some structures that will be refactored by that PR (namely `return_type_to_ivalue` and `ivalue_to_arg_type`).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/21098

Differential Revision: D16634872

Pulled By: Chillee

fbshipit-source-id: 1408bb89ea649c27d560df59e2cf9920467fe1de
2019-08-02 18:45:15 -07:00


import unittest
import torch
from torch import ops
import torch.jit as jit
import glob
import os


def get_custom_class_library_path():
    library_filename = glob.glob("build/*custom_class*")
    assert (len(library_filename) == 1)
    library_filename = library_filename[0]
    path = os.path.abspath(library_filename)
    assert os.path.exists(path), path
    return path


def test_equality(f, cmp_key):
    obj1 = f()
    obj2 = jit.script(f)()
    return (cmp_key(obj1), cmp_key(obj2))


class TestCustomOperators(unittest.TestCase):
    def setUp(self):
        ops.load_library(get_custom_class_library_path())

    def test_no_return_class(self):
        def f():
            val = torch.classes.Foo(5, 3)
            return val.info()
        self.assertEqual(*test_equality(f, lambda x: x))

    def test_constructor_with_args(self):
        def f():
            val = torch.classes.Foo(5, 3)
            return val
        self.assertEqual(*test_equality(f, lambda x: x.info()))

    def test_function_call_with_args(self):
        def f():
            val = torch.classes.Foo(5, 3)
            val.increment(1)
            return val
        self.assertEqual(*test_equality(f, lambda x: x.info()))

    def test_function_method_wrong_type(self):
        def f():
            val = torch.classes.Foo(5, 3)
            val.increment("asdf")
            return val
        with self.assertRaisesRegex(RuntimeError, "Expected"):
            jit.script(f)()

    @unittest.skip("We currently don't support passing custom classes to custom methods.")
    def test_input_class_type(self):
        def f():
            val = torch.classes.Foo(1, 2)
            val2 = torch.classes.Foo(2, 3)
            val.combine(val2)
            return val
        self.assertEqual(*test_equality(f, lambda x: x.info()))

    def test_stack_string(self):
        def f():
            val = torch.classes.StackString(["asdf", "bruh"])
            return val.pop()
        self.assertEqual(*test_equality(f, lambda x: x))

    def test_stack_push_pop(self):
        def f():
            val = torch.classes.StackString(["asdf", "bruh"])
            val2 = torch.classes.StackString(["111", "222"])
            val.push(val2.pop())
            return val.pop() + val2.pop()
        self.assertEqual(*test_equality(f, lambda x: x))


if __name__ == "__main__":
    unittest.main()