Mirror of https://github.com/zebrajr/pytorch.git (synced 2025-12-06 12:20:52 +01:00)
skip torchbind in constant folding (#148993)
Summary:
Do not fold torchbind objects in constant folding. Any operation on a torchbind object can have arbitrary side effects, so nothing torchbind-object-related can be safely constant folded.

Test Plan:
```
buck run fbcode//mode/dev-nosan //caffe2/test/inductor:torchbind -- -r aot_compile_constant_folding
```

Reviewed By: angelayi

Differential Revision: D69946541

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148993
Approved by: https://github.com/angelayi
This commit is contained in:
parent 923ce10f6c
commit 01e9036bd2
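Why folding is unsafe here, in a torchbind-free sketch: constant folding caches an op's first result and reuses it, which silently changes observable behavior once evaluation has side effects. The `Counter` class below is illustrative only, not from the commit.

```python
import torch


class Counter:
    """Stand-in for a stateful torchbind object."""

    def __init__(self) -> None:
        self.calls = 0

    def next(self) -> torch.Tensor:
        # Impure: every call mutates state, so two calls are not
        # interchangeable with one cached (folded) result.
        self.calls += 1
        return torch.tensor(float(self.calls))


c = Counter()
print(c.next())  # tensor(1.)
print(c.next())  # tensor(2.) -- folding to the first result would wrongly repeat 1.
```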
@@ -234,6 +234,13 @@ class TestTorchbind(TestCase):
         # TODO: add accuracy test after we support loading and running compiled models with
         # torchbind objects.
 
+    @torch._inductor.config.patch("aot_inductor.use_runtime_constant_folding", True)
+    def test_torchbind_aot_compile_constant_folding(self):
+        ep, inputs, _, _ = self.get_exported_model()
+        aot_compile(ep.module(), inputs, options={"aot_inductor.package": True})
+        # TODO: add accuracy test after we support loading and running compiled models with
+        # torchbind objects.
+
 
 if __name__ == "__main__":
     run_tests()
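A hedged usage note on the decorator in the new test: `torch._inductor.config.patch` also works as a context manager, so the same run-time constant-folding toggle can be scoped to a single block. The body below is a placeholder, not part of the commit.

```python
import torch._inductor.config as inductor_config

with inductor_config.patch("aot_inductor.use_runtime_constant_folding", True):
    # export the torchbind-holding module and call aot_compile here,
    # as the new test does via its decorator
    ...
```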
@@ -1070,7 +1070,16 @@ class _InProcessFxCompile(FxCompile):
                 const_kernel_code = None
 
                 if aot_mode and config.aot_inductor.use_runtime_constant_folding:
-                    const_gm, const_output_index = split_const_gm(gm)
+                    # torchbind objects have name that starts with _torchbind_obj
+                    # See caffe2/torch/fx/_symbolic_trace.py?lines=406
+                    # We don't use node.meta["val"] because we don't typically
+                    # attach meta["val"] for get_attr nodes.
+                    const_gm, const_output_index = split_const_gm(
+                        gm,
+                        skip_folding_node_fn=lambda node: node.op == "get_attr"
+                        and isinstance(node.target, str)
+                        and node.target.startswith("_torchbind_obj"),
+                    )
 
                     const_graph = GraphLowering(
                         const_gm,
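To see what the new `skip_folding_node_fn` matches, here is a minimal, self-contained sketch. The `Toy` module is illustrative: a plain parameter stands in for a real torchbind object, which works here because the predicate only inspects a node's op and target name, never its value.

```python
import torch
import torch.fx


class Toy(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        # Name chosen to match the _torchbind_obj prefix convention.
        self._torchbind_obj0 = torch.nn.Parameter(torch.ones(2))
        self.weight = torch.nn.Parameter(torch.ones(2))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self._torchbind_obj0 + self.weight


def skip_folding_node_fn(node: torch.fx.Node) -> bool:
    # Same shape as the lambda passed to split_const_gm in the diff.
    return (
        node.op == "get_attr"
        and isinstance(node.target, str)
        and node.target.startswith("_torchbind_obj")
    )


gm = torch.fx.symbolic_trace(Toy())
for node in gm.graph.nodes:
    print(node.op, node.target, "-> skipped" if skip_folding_node_fn(node) else "")
```

Both parameters appear as `get_attr` nodes in the traced graph, but only the one whose name carries the `_torchbind_obj` prefix is excluded from folding.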
@@ -233,6 +233,10 @@ class ConstantFolder(torch.fx.Interpreter):
             return self.unknown_value
 
         out = self._deduce_value(node)
+
+        if isinstance(out, torch._C.ScriptObject):
+            return out
+
         if out == self.unknown_value:
             return self.unknown_value
 
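For context, a sketch of the value-handling order this hunk establishes. Only the two conditionals come from the diff; the standalone helper and its name are hypothetical. The `torch._C.ScriptObject` check (the C++-backed type behind torchbind objects) runs before the `unknown_value` comparison, so a torchbind object is handed back as-is rather than being treated as a foldable constant.

```python
import torch


def handle_deduced_value(out: object, unknown_value: object) -> object:
    # Hypothetical standalone wrapper around the logic added in run_node.
    if isinstance(out, torch._C.ScriptObject):
        # torchbind object: return it untouched, never fold it.
        return out
    if out == unknown_value:
        return unknown_value
    return out
```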