From 960f4b51e3403f14b407328823bd38e8e1067592 Mon Sep 17 00:00:00 2001 From: Michael Voznesensky Date: Thu, 14 May 2020 21:10:14 -0700 Subject: [PATCH] [JIT] Fix `@staticmethod` access from `self` on modules (#37702) Summary: Closes https://github.com/pytorch/pytorch/issues/30755 Pull Request resolved: https://github.com/pytorch/pytorch/pull/37702 Differential Revision: D21389989 Pulled By: voznesenskym fbshipit-source-id: f9b7e26a9eab7dc3d7762a5a28f85424dac5fbb3 --- test/test_jit.py | 24 +++++++++++++++++++ torch/_jit_internal.py | 8 +++++++ .../csrc/jit/python/python_sugared_value.cpp | 10 ++++++++ 3 files changed, 42 insertions(+) diff --git a/test/test_jit.py b/test/test_jit.py index b929573f956..922d7595117 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -3518,6 +3518,30 @@ class TestScript(JitTestCase): .check("aten::mul") \ .run(m.inlined_graph) + def test_static_method_on_module(self): + """ + Check that the `@staticmethod` annotation on a function on a module works. + """ + class MyCell(torch.nn.Module): + def __init__(self): + super(MyCell, self).__init__() + + @staticmethod + def do_it(x, h): + new_h = torch.tanh(x + h) + return new_h, new_h + + def forward(self, x, h): + return self.do_it(x, h) + + my_cell = torch.jit.script(MyCell()) + x = torch.rand(3, 4) + h = torch.rand(3, 4) + jitted_cell = my_cell(x, h) + non_jitted_cell = MyCell().do_it(x, h) + + self.assertEqual(jitted_cell, non_jitted_cell) + def test_code_with_constants(self): """ Check that the `code_with_constants` property correctly returns graph CONSTANTS in the diff --git a/torch/_jit_internal.py b/torch/_jit_internal.py index fdf506a4f99..3a18551a366 100644 --- a/torch/_jit_internal.py +++ b/torch/_jit_internal.py @@ -453,6 +453,14 @@ def is_ignored_fn(fn): mod = get_torchscript_modifier(fn) return mod is FunctionModifiers.UNUSED or mod is FunctionModifiers.IGNORE + +def is_static_fn(cls, fn): + return isinstance(inspect.getattr_static(cls, fn), staticmethod) + +def get_static_fn(cls, fn): + return inspect.getattr_static(cls, fn).__func__ + + def get_torchscript_modifier(fn): if not callable(fn): return None diff --git a/torch/csrc/jit/python/python_sugared_value.cpp b/torch/csrc/jit/python/python_sugared_value.cpp index c31b5081119..a1a8cfa73c2 100644 --- a/torch/csrc/jit/python/python_sugared_value.cpp +++ b/torch/csrc/jit/python/python_sugared_value.cpp @@ -426,7 +426,17 @@ std::shared_ptr ModuleValue::tryGetAttr( concreteType_->getPyClass(), field.c_str(), pybind11::cast(Py_None)); + if (py::isinstance(unboundMethod)) { + bool isStaticFn = + py::cast(py::module::import("torch._jit_internal") + .attr("is_static_fn")(concreteType_->getPyClass(), field.c_str())); + if (isStaticFn) { + // Functions within the module annotated with @staticmethod do not need binding. + py::object staticFn = py::module::import("torch._jit_internal") + .attr("get_static_fn")(concreteType_->getPyClass(), field.c_str()); + return toSugaredValue(staticFn, m, loc); + } // For Python methods that we're trying to call directly, we need to bind // the method to a self. (see the documentation for lazy_bind in Python for // more info).