skip torchbind in cosntant folding (#148993)

Summary:
Do not fold torchbind objects in constant folding

Any operation on these torchbind objects can have arbitrary side effects, so we can't effectively constant fold anything torchbind-obj-related anyway.

Test Plan:
```
buck run fbcode//mode/dev-nosan //caffe2/test/inductor:torchbind -- -r aot_compile_constant_folding
```

Reviewed By: angelayi

Differential Revision: D69946541

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148993
Approved by: https://github.com/angelayi
This commit is contained in:
Shangdi Yu 2025-03-12 18:08:08 +00:00 committed by PyTorch MergeBot
parent 923ce10f6c
commit 01e9036bd2
3 changed files with 21 additions and 1 deletions

View File

@ -234,6 +234,13 @@ class TestTorchbind(TestCase):
# TODO: add accuracy test after we support loading and running compiled models with # TODO: add accuracy test after we support loading and running compiled models with
# torchbind objects. # torchbind objects.
@torch._inductor.config.patch("aot_inductor.use_runtime_constant_folding", True)
def test_torchbind_aot_compile_constant_folding(self):
ep, inputs, _, _ = self.get_exported_model()
aot_compile(ep.module(), inputs, options={"aot_inductor.package": True})
# TODO: add accuracy test after we support loading and running compiled models with
# torchbind objects.
if __name__ == "__main__": if __name__ == "__main__":
run_tests() run_tests()

View File

@ -1070,7 +1070,16 @@ class _InProcessFxCompile(FxCompile):
const_kernel_code = None const_kernel_code = None
if aot_mode and config.aot_inductor.use_runtime_constant_folding: if aot_mode and config.aot_inductor.use_runtime_constant_folding:
const_gm, const_output_index = split_const_gm(gm) # torchbind objects have name that starts with _torchbind_obj
# See caffe2/torch/fx/_symbolic_trace.py?lines=406
# We don't use node.meta["val"] because we don't typically
# attach meta["val"] for get_attr nodes.
const_gm, const_output_index = split_const_gm(
gm,
skip_folding_node_fn=lambda node: node.op == "get_attr"
and isinstance(node.target, str)
and node.target.startswith("_torchbind_obj"),
)
const_graph = GraphLowering( const_graph = GraphLowering(
const_gm, const_gm,

View File

@ -233,6 +233,10 @@ class ConstantFolder(torch.fx.Interpreter):
return self.unknown_value return self.unknown_value
out = self._deduce_value(node) out = self._deduce_value(node)
if isinstance(out, torch._C.ScriptObject):
return out
if out == self.unknown_value: if out == self.unknown_value:
return self.unknown_value return self.unknown_value