mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 00:21:07 +01:00
[JIT] Fix @staticmethod access from self on modules (#37702)
Summary: Closes https://github.com/pytorch/pytorch/issues/30755 Pull Request resolved: https://github.com/pytorch/pytorch/pull/37702 Differential Revision: D21389989 Pulled By: voznesenskym fbshipit-source-id: f9b7e26a9eab7dc3d7762a5a28f85424dac5fbb3
This commit is contained in:
parent
3d0532f3ab
commit
960f4b51e3
|
|
@ -3518,6 +3518,30 @@ class TestScript(JitTestCase):
|
|||
.check("aten::mul") \
|
||||
.run(m.inlined_graph)
|
||||
|
||||
def test_static_method_on_module(self):
|
||||
"""
|
||||
Check that the `@staticmethod` annotation on a function on a module works.
|
||||
"""
|
||||
class MyCell(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super(MyCell, self).__init__()
|
||||
|
||||
@staticmethod
|
||||
def do_it(x, h):
|
||||
new_h = torch.tanh(x + h)
|
||||
return new_h, new_h
|
||||
|
||||
def forward(self, x, h):
|
||||
return self.do_it(x, h)
|
||||
|
||||
my_cell = torch.jit.script(MyCell())
|
||||
x = torch.rand(3, 4)
|
||||
h = torch.rand(3, 4)
|
||||
jitted_cell = my_cell(x, h)
|
||||
non_jitted_cell = MyCell().do_it(x, h)
|
||||
|
||||
self.assertEqual(jitted_cell, non_jitted_cell)
|
||||
|
||||
def test_code_with_constants(self):
|
||||
"""
|
||||
Check that the `code_with_constants` property correctly returns graph CONSTANTS in the
|
||||
|
|
|
|||
|
|
@ -453,6 +453,14 @@ def is_ignored_fn(fn):
|
|||
mod = get_torchscript_modifier(fn)
|
||||
return mod is FunctionModifiers.UNUSED or mod is FunctionModifiers.IGNORE
|
||||
|
||||
|
||||
def is_static_fn(cls, fn):
|
||||
return isinstance(inspect.getattr_static(cls, fn), staticmethod)
|
||||
|
||||
def get_static_fn(cls, fn):
|
||||
return inspect.getattr_static(cls, fn).__func__
|
||||
|
||||
|
||||
def get_torchscript_modifier(fn):
|
||||
if not callable(fn):
|
||||
return None
|
||||
|
|
|
|||
|
|
@ -426,7 +426,17 @@ std::shared_ptr<SugaredValue> ModuleValue::tryGetAttr(
|
|||
concreteType_->getPyClass(),
|
||||
field.c_str(),
|
||||
pybind11::cast<pybind11::none>(Py_None));
|
||||
|
||||
if (py::isinstance<py::function>(unboundMethod)) {
|
||||
bool isStaticFn =
|
||||
py::cast<bool>(py::module::import("torch._jit_internal")
|
||||
.attr("is_static_fn")(concreteType_->getPyClass(), field.c_str()));
|
||||
if (isStaticFn) {
|
||||
// Functions within the module annotated with @staticmethod do not need binding.
|
||||
py::object staticFn = py::module::import("torch._jit_internal")
|
||||
.attr("get_static_fn")(concreteType_->getPyClass(), field.c_str());
|
||||
return toSugaredValue(staticFn, m, loc);
|
||||
}
|
||||
// For Python methods that we're trying to call directly, we need to bind
|
||||
// the method to a self. (see the documentation for lazy_bind in Python for
|
||||
// more info).
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user