mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
[JIT] Constant prop getattr (#49806)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/49806 Fix for https://github.com/pytorch/pytorch/issues/47089 Test Plan: Imported from OSS Reviewed By: navahgar Differential Revision: D25696791 Pulled By: eellison fbshipit-source-id: 914c17b8effef7f4f341775ac2b8150ee4703efd
This commit is contained in:
parent
268441c7d8
commit
fc559bd6dc
|
|
@ -1331,3 +1331,33 @@ class TestFreezing(JitTestCase):
|
|||
m.eval()
|
||||
with self.assertRaisesRegex(RuntimeError, "Freezing modules containing prim::ModuleDictIndex is not supported"):
|
||||
mf = torch._C._freeze_module(m._c)
|
||||
|
||||
|
||||
def test_freeze_non_module_class_getattr(self):
|
||||
class BoxCoder(object):
|
||||
def __init__(self, bbox_xform_clip):
|
||||
# type: (float) -> None
|
||||
self.bbox_xform_clip = bbox_xform_clip
|
||||
|
||||
def decode(self, input):
|
||||
return input * self.bbox_xform_clip
|
||||
|
||||
class MyModule(torch.nn.Module):
|
||||
__annotations__ = {
|
||||
'box_coder': BoxCoder,
|
||||
}
|
||||
|
||||
def __init__(self):
|
||||
super(MyModule, self).__init__()
|
||||
self.box_coder = BoxCoder(50.)
|
||||
|
||||
def forward(self, input):
|
||||
return self.box_coder.decode(input)
|
||||
|
||||
model = MyModule()
|
||||
model.eval()
|
||||
script_model = torch.jit.freeze(torch.jit.script(model))
|
||||
inp = torch.randn([4, 4])
|
||||
output_eager = model(inp)
|
||||
self.assertEqual(model(inp), script_model(inp))
|
||||
FileCheck().check_not("GetAttr").run(script_model.graph)
|
||||
|
|
|
|||
|
|
@ -54,6 +54,10 @@ c10::optional<std::vector<IValue>> runNodeIfInputsAreConstant(
|
|||
case prim::CreateObject: {
|
||||
createObject(stack, n->output()->type()->expect<ClassType>());
|
||||
} break;
|
||||
case prim::GetAttr: {
|
||||
auto attr = pop(stack).toObject()->getAttr(n->s(attr::name));
|
||||
push(stack, attr);
|
||||
} break;
|
||||
case prim::isinstance: {
|
||||
isinstance(stack, n->tys(attr::types));
|
||||
} break;
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user