diff --git a/test/inductor/test_torchbind.py b/test/inductor/test_torchbind.py index d948bed8172..248b0324cd8 100644 --- a/test/inductor/test_torchbind.py +++ b/test/inductor/test_torchbind.py @@ -234,6 +234,13 @@ class TestTorchbind(TestCase): # TODO: add accuracy test after we support loading and running compiled models with # torchbind objects. + @torch._inductor.config.patch("aot_inductor.use_runtime_constant_folding", True) + def test_torchbind_aot_compile_constant_folding(self): + ep, inputs, _, _ = self.get_exported_model() + aot_compile(ep.module(), inputs, options={"aot_inductor.package": True}) + # TODO: add accuracy test after we support loading and running compiled models with + # torchbind objects. + if __name__ == "__main__": run_tests() diff --git a/torch/_inductor/compile_fx.py b/torch/_inductor/compile_fx.py index 34930377158..339c93e20a3 100644 --- a/torch/_inductor/compile_fx.py +++ b/torch/_inductor/compile_fx.py @@ -1070,7 +1070,16 @@ class _InProcessFxCompile(FxCompile): const_kernel_code = None if aot_mode and config.aot_inductor.use_runtime_constant_folding: - const_gm, const_output_index = split_const_gm(gm) + # torchbind objects have name that starts with _torchbind_obj + # See caffe2/torch/fx/_symbolic_trace.py?lines=406 + # We don't use node.meta["val"] because we don't typically + # attach meta["val"] for get_attr nodes. + const_gm, const_output_index = split_const_gm( + gm, + skip_folding_node_fn=lambda node: node.op == "get_attr" + and isinstance(node.target, str) + and node.target.startswith("_torchbind_obj"), + ) const_graph = GraphLowering( const_gm, diff --git a/torch/_inductor/constant_folding.py b/torch/_inductor/constant_folding.py index 1972bcc3583..93237d67e01 100644 --- a/torch/_inductor/constant_folding.py +++ b/torch/_inductor/constant_folding.py @@ -233,6 +233,10 @@ class ConstantFolder(torch.fx.Interpreter): return self.unknown_value out = self._deduce_value(node) + + if isinstance(out, torch._C.ScriptObject): + return out + if out == self.unknown_value: return self.unknown_value