mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
skip torchbind in cosntant folding (#148993)
Summary: Do not fold torchbind objects in constant folding Any operation on these torchbind objects can have arbitrary side effects, so we can't effectively constant fold anything torchbind-obj-related anyway. Test Plan: ``` buck run fbcode//mode/dev-nosan //caffe2/test/inductor:torchbind -- -r aot_compile_constant_folding ``` Reviewed By: angelayi Differential Revision: D69946541 Pull Request resolved: https://github.com/pytorch/pytorch/pull/148993 Approved by: https://github.com/angelayi
This commit is contained in:
parent
923ce10f6c
commit
01e9036bd2
|
|
@ -234,6 +234,13 @@ class TestTorchbind(TestCase):
|
|||
# TODO: add accuracy test after we support loading and running compiled models with
|
||||
# torchbind objects.
|
||||
|
||||
@torch._inductor.config.patch("aot_inductor.use_runtime_constant_folding", True)
|
||||
def test_torchbind_aot_compile_constant_folding(self):
|
||||
ep, inputs, _, _ = self.get_exported_model()
|
||||
aot_compile(ep.module(), inputs, options={"aot_inductor.package": True})
|
||||
# TODO: add accuracy test after we support loading and running compiled models with
|
||||
# torchbind objects.
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
|
|
|
|||
|
|
@ -1070,7 +1070,16 @@ class _InProcessFxCompile(FxCompile):
|
|||
const_kernel_code = None
|
||||
|
||||
if aot_mode and config.aot_inductor.use_runtime_constant_folding:
|
||||
const_gm, const_output_index = split_const_gm(gm)
|
||||
# torchbind objects have name that starts with _torchbind_obj
|
||||
# See caffe2/torch/fx/_symbolic_trace.py?lines=406
|
||||
# We don't use node.meta["val"] because we don't typically
|
||||
# attach meta["val"] for get_attr nodes.
|
||||
const_gm, const_output_index = split_const_gm(
|
||||
gm,
|
||||
skip_folding_node_fn=lambda node: node.op == "get_attr"
|
||||
and isinstance(node.target, str)
|
||||
and node.target.startswith("_torchbind_obj"),
|
||||
)
|
||||
|
||||
const_graph = GraphLowering(
|
||||
const_gm,
|
||||
|
|
|
|||
|
|
@ -233,6 +233,10 @@ class ConstantFolder(torch.fx.Interpreter):
|
|||
return self.unknown_value
|
||||
|
||||
out = self._deduce_value(node)
|
||||
|
||||
if isinstance(out, torch._C.ScriptObject):
|
||||
return out
|
||||
|
||||
if out == self.unknown_value:
|
||||
return self.unknown_value
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user