diff --git a/test/test_jit.py b/test/test_jit.py index b929573f956..922d7595117 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -3518,6 +3518,30 @@ class TestScript(JitTestCase): .check("aten::mul") \ .run(m.inlined_graph) + def test_static_method_on_module(self): + """ + Check that the `@staticmethod` annotation on a function on a module works. + """ + class MyCell(torch.nn.Module): + def __init__(self): + super(MyCell, self).__init__() + + @staticmethod + def do_it(x, h): + new_h = torch.tanh(x + h) + return new_h, new_h + + def forward(self, x, h): + return self.do_it(x, h) + + my_cell = torch.jit.script(MyCell()) + x = torch.rand(3, 4) + h = torch.rand(3, 4) + jitted_cell = my_cell(x, h) + non_jitted_cell = MyCell().do_it(x, h) + + self.assertEqual(jitted_cell, non_jitted_cell) + def test_code_with_constants(self): """ Check that the `code_with_constants` property correctly returns graph CONSTANTS in the diff --git a/torch/_jit_internal.py b/torch/_jit_internal.py index fdf506a4f99..3a18551a366 100644 --- a/torch/_jit_internal.py +++ b/torch/_jit_internal.py @@ -453,6 +453,14 @@ def is_ignored_fn(fn): mod = get_torchscript_modifier(fn) return mod is FunctionModifiers.UNUSED or mod is FunctionModifiers.IGNORE + +def is_static_fn(cls, fn): + return isinstance(inspect.getattr_static(cls, fn), staticmethod) + +def get_static_fn(cls, fn): + return inspect.getattr_static(cls, fn).__func__ + + def get_torchscript_modifier(fn): if not callable(fn): return None diff --git a/torch/csrc/jit/python/python_sugared_value.cpp b/torch/csrc/jit/python/python_sugared_value.cpp index c31b5081119..a1a8cfa73c2 100644 --- a/torch/csrc/jit/python/python_sugared_value.cpp +++ b/torch/csrc/jit/python/python_sugared_value.cpp @@ -426,7 +426,17 @@ std::shared_ptr ModuleValue::tryGetAttr( concreteType_->getPyClass(), field.c_str(), pybind11::cast(Py_None)); + if (py::isinstance(unboundMethod)) { + bool isStaticFn = + py::cast(py::module::import("torch._jit_internal") + .attr("is_static_fn")(concreteType_->getPyClass(), field.c_str())); + if (isStaticFn) { + // Functions within the module annotated with @staticmethod do not need binding. + py::object staticFn = py::module::import("torch._jit_internal") + .attr("get_static_fn")(concreteType_->getPyClass(), field.c_str()); + return toSugaredValue(staticFn, m, loc); + } // For Python methods that we're trying to call directly, we need to bind // the method to a self. (see the documentation for lazy_bind in Python for // more info).