[JIT] Constant prop getattr (#49806)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/49806

Fix for https://github.com/pytorch/pytorch/issues/47089

Test Plan: Imported from OSS

Reviewed By: navahgar

Differential Revision: D25696791

Pulled By: eellison

fbshipit-source-id: 914c17b8effef7f4f341775ac2b8150ee4703efd
This commit is contained in:
Elias Ellison 2020-12-28 10:41:28 -08:00 committed by Facebook GitHub Bot
parent 268441c7d8
commit fc559bd6dc
2 changed files with 34 additions and 0 deletions

View File

@ -1331,3 +1331,33 @@ class TestFreezing(JitTestCase):
m.eval()
with self.assertRaisesRegex(RuntimeError, "Freezing modules containing prim::ModuleDictIndex is not supported"):
mf = torch._C._freeze_module(m._c)
def test_freeze_non_module_class_getattr(self):
class BoxCoder(object):
def __init__(self, bbox_xform_clip):
# type: (float) -> None
self.bbox_xform_clip = bbox_xform_clip
def decode(self, input):
return input * self.bbox_xform_clip
class MyModule(torch.nn.Module):
__annotations__ = {
'box_coder': BoxCoder,
}
def __init__(self):
super(MyModule, self).__init__()
self.box_coder = BoxCoder(50.)
def forward(self, input):
return self.box_coder.decode(input)
model = MyModule()
model.eval()
script_model = torch.jit.freeze(torch.jit.script(model))
inp = torch.randn([4, 4])
output_eager = model(inp)
self.assertEqual(model(inp), script_model(inp))
FileCheck().check_not("GetAttr").run(script_model.graph)

View File

@ -54,6 +54,10 @@ c10::optional<std::vector<IValue>> runNodeIfInputsAreConstant(
case prim::CreateObject: {
createObject(stack, n->output()->type()->expect<ClassType>());
} break;
case prim::GetAttr: {
auto attr = pop(stack).toObject()->getAttr(n->s(attr::name));
push(stack, attr);
} break;
case prim::isinstance: {
isinstance(stack, n->tys(attr::types));
} break;