[JIT] Fix @staticmethod access from self on modules (#37702)

Summary:
Closes https://github.com/pytorch/pytorch/issues/30755
Pull Request resolved: https://github.com/pytorch/pytorch/pull/37702

Differential Revision: D21389989

Pulled By: voznesenskym

fbshipit-source-id: f9b7e26a9eab7dc3d7762a5a28f85424dac5fbb3
This commit is contained in:
Michael Voznesensky 2020-05-14 21:10:14 -07:00 committed by Facebook GitHub Bot
parent 3d0532f3ab
commit 960f4b51e3
3 changed files with 42 additions and 0 deletions

View File

@ -3518,6 +3518,30 @@ class TestScript(JitTestCase):
.check("aten::mul") \
.run(m.inlined_graph)
def test_static_method_on_module(self):
"""
Check that the `@staticmethod` annotation on a function on a module works.
"""
class MyCell(torch.nn.Module):
def __init__(self):
super(MyCell, self).__init__()
@staticmethod
def do_it(x, h):
new_h = torch.tanh(x + h)
return new_h, new_h
def forward(self, x, h):
return self.do_it(x, h)
my_cell = torch.jit.script(MyCell())
x = torch.rand(3, 4)
h = torch.rand(3, 4)
jitted_cell = my_cell(x, h)
non_jitted_cell = MyCell().do_it(x, h)
self.assertEqual(jitted_cell, non_jitted_cell)
def test_code_with_constants(self):
"""
Check that the `code_with_constants` property correctly returns graph CONSTANTS in the

View File

@ -453,6 +453,14 @@ def is_ignored_fn(fn):
mod = get_torchscript_modifier(fn)
return mod is FunctionModifiers.UNUSED or mod is FunctionModifiers.IGNORE
def is_static_fn(cls, fn):
return isinstance(inspect.getattr_static(cls, fn), staticmethod)
def get_static_fn(cls, fn):
return inspect.getattr_static(cls, fn).__func__
def get_torchscript_modifier(fn):
if not callable(fn):
return None

View File

@ -426,7 +426,17 @@ std::shared_ptr<SugaredValue> ModuleValue::tryGetAttr(
concreteType_->getPyClass(),
field.c_str(),
pybind11::cast<pybind11::none>(Py_None));
if (py::isinstance<py::function>(unboundMethod)) {
bool isStaticFn =
py::cast<bool>(py::module::import("torch._jit_internal")
.attr("is_static_fn")(concreteType_->getPyClass(), field.c_str()));
if (isStaticFn) {
// Functions within the module annotated with @staticmethod do not need binding.
py::object staticFn = py::module::import("torch._jit_internal")
.attr("get_static_fn")(concreteType_->getPyClass(), field.c_str());
return toSugaredValue(staticFn, m, loc);
}
// For Python methods that we're trying to call directly, we need to bind
// the method to a self. (see the documentation for lazy_bind in Python for
// more info).