Fakify torchbind objects in compile_fx and add tests for SigridTransformsInstanceTorchBind (#149529)

Summary:
We need to properly fakify torchbind objects, including those stored as graph module attributes, so that the registered fake implementations work properly.

- Use `_fakify_script_objects` in `compile_fx`
- Allow fake torchbind objects in `torchbind_constants`
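
For background, fakification swaps each real `ScriptObject` for a `FakeScriptObject` wrapping an instance of the fake class registered for that torchbind type, which is what the fake op implementations operate on. A minimal sketch of such a registration, assuming the usual fake-class protocol (the `FakeFoo` body below is illustrative, not part of this diff):

```
# Sketch only: roughly how a fake class is registered for a torchbind type so that
# fakified objects (FakeScriptObject) can flow through FakeTensorMode tracing.
import torch


@torch._library.register_fake_class("_TorchScriptTesting::_Foo")
class FakeFoo:  # illustrative body, not the actual registration in the test utils
    def __init__(self, x: int, y: int):
        self.x = x
        self.y = y

    @classmethod
    def __obj_unflatten__(cls, flattened_foo):
        # flattened_foo is the sequence of (name, value) pairs produced by the
        # real object's __obj_flatten__ method
        return cls(**dict(flattened_foo))
```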

Remove `node.meta["unbacked_bindings"]` for `aot_compile` in `compile_fx`. Otherwise `ShapeProp` will fail when trying to resolve the `unbacked_bindings` of `with_effect` tokens.
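
For background, `with_effect` wraps the effectful call's output in a `(token, result)` tuple, so unbacked-symbol paths recorded on that node carry an extra leading key; the `_remove_effect_tokens` pass change below strips that layer so the paths point into the op's real output. A sketch of the adjustment (the `before` path is illustrative; the key names follow the new test):

```
# Hypothetical bindings on a with_effect node whose result has an unbacked size(0):
#   before: {u0: (SequenceKey(idx=1), CallMethodKey(name="size"), SequenceKey(idx=0))}
#   after:  {u0: (CallMethodKey(name="size"), SequenceKey(idx=0))}
new_bindings = {
    sym: path[1:] if path else path for sym, path in old_bindings.items()
}
```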

Update `sigrid_transforms_test` to use the latest `torch._inductor.aot_compile` API.
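
For reference, a minimal sketch of that flow with a stand-in module (the actual test uses the sigrid transforms torchbind class; the names here mirror the `_TorchScriptTesting` test op used elsewhere in this diff):

```
# Sketch, assuming the _TorchScriptTesting custom-class library is loaded.
import torch


class MyModule(torch.nn.Module):  # hypothetical stand-in for the sigrid transforms module
    def __init__(self):
        super().__init__()
        self.attr = torch.classes._TorchScriptTesting._Foo(2, 3)

    def forward(self, x):
        return torch.ops._TorchScriptTesting.takes_foo_tensor_return(self.attr, x)


example_input = torch.ones(2, 3)
ep = torch.export.export(MyModule(), (example_input,), strict=False)
# torch._inductor.aot_compile lowers the exported graph module through compile_fx's
# cpp_wrapper/AOT path (where _fakify_script_objects is now applied) and returns the
# path to the compiled artifact.
so_path = torch._inductor.aot_compile(ep.module(), (example_input,))
```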

Add a test for `SigridTransformsInstanceTorchBind` in `e2e_test`.

Test Plan:
```
buck run //caffe2/torch/fb/sparsenn:sigrid_test -- -r test_transform_torch_bind

buck run //sigmoid/inference/test:e2e_test_cpu -- -r SigridTransforms

buck2 run mode/dev-nosan sigmoid/inference/ts_migration:pt2i_readiness_main -- --model_id 545017754 --test_suite ads_all --mode test_preproc

```

Differential Revision: D70013257

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149529
Approved by: https://github.com/angelayi
Shangdi Yu authored 2025-03-21 18:58:23 +00:00, committed by PyTorch MergeBot
parent 19b763def1
commit 46dd226702
10 changed files with 89 additions and 15 deletions

View File

@@ -518,6 +518,8 @@ TORCH_LIBRARY(_TorchScriptTesting, m) {
      "takes_foo_list_return(__torch__.torch.classes._TorchScriptTesting._Foo foo, Tensor x) -> Tensor[]");
  m.def(
      "takes_foo_tuple_return(__torch__.torch.classes._TorchScriptTesting._Foo foo, Tensor x) -> (Tensor, Tensor)");
  m.def(
      "takes_foo_tensor_return(__torch__.torch.classes._TorchScriptTesting._Foo foo, Tensor x) -> Tensor");
  m.class_<FooGetterSetter>("_FooGetterSetter")
      .def(torch::init<int64_t, int64_t>())
@@ -701,6 +703,10 @@ std::tuple<at::Tensor, at::Tensor> takes_foo_tuple_return(
  return std::make_tuple(a, b);
}

at::Tensor takes_foo_tensor_return(c10::intrusive_ptr<Foo> foo, at::Tensor x) {
  return at::ones({foo->x, foo->y}, at::device(at::kCPU).dtype(at::kInt));
}

void queue_push(c10::intrusive_ptr<TensorQueue> tq, at::Tensor x) {
  tq->push(x);
}
@@ -732,6 +738,7 @@ TORCH_LIBRARY_IMPL(_TorchScriptTesting, CPU, m) {
  m.impl("queue_push", queue_push);
  m.impl("queue_pop", queue_pop);
  m.impl("queue_size", queue_size);
  m.impl("takes_foo_tensor_return", takes_foo_tensor_return);
}

TORCH_LIBRARY_IMPL(_TorchScriptTesting, Meta, m) {

View File

@@ -561,6 +561,40 @@ def forward(self, token, obj_attr, x):
return (getitem_3, add_1)""", # noqa: B950
)
@parametrize("pre_dispatch", [True, False])
def test_custom_obj_unbacked_symint(self, pre_dispatch):
class MyModule(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.attr = torch.classes._TorchScriptTesting._Foo(2, 3)
def forward(self, x):
a = torch.ops._TorchScriptTesting.takes_foo_tensor_return(self.attr, x)
return a
input = torch.ones(2, 3)
ep = self._test_export_same_as_eager(
MyModule(), (input,), strict=False, pre_dispatch=pre_dispatch
)
gm = ep.module()
foo_node = next(
n
for n in gm.graph.nodes
if n.target == torch.ops._TorchScriptTesting.takes_foo_tensor_return.default
)
unbacked_bindings = foo_node.meta["unbacked_bindings"]
self.assertEqual(len(unbacked_bindings), 2)
u = next(iter(unbacked_bindings.keys()))
path = unbacked_bindings[u]
# the unbacked bindings should be CallMethodKey(name='size'), SequenceKey(idx=0)
# it should not include the effect token in the path
self.assertEqual(
type(u).__name__, "Symbol"
) # check binding is symbol, not expr
self.assertEqual(len(path), 2)
self.assertEqual(path[0].name, "size")
self.assertEqual(path[1].idx, 0)
@parametrize("make_fx_tracing_mode", ["fake", "symbolic"])
def test_make_fx_tensor_queue_methods(self, make_fx_tracing_mode):
test = self

View File

@@ -3,6 +3,7 @@ import contextlib
import inspect
import logging
from collections import defaultdict
from collections.abc import Sequence
from typing import Any, Callable, Optional, TYPE_CHECKING, Union
import torch
@@ -533,7 +534,7 @@ def _fakify_module_inputs(
@contextlib.contextmanager
def _fakify_script_objects(
mod: torch.nn.Module,
args: tuple[Any],
args: Sequence[Any],
kwargs: dict[Any, Any],
fake_mode: torch._subclasses.fake_tensor.FakeTensorMode,
):

View File

@@ -1761,6 +1761,10 @@ class AotCodeCompiler:
        for custom_obj_idx, (name, constant) in enumerate(
            graph.torchbind_constants.items()
        ):
            if isinstance(
                constant, torch._library.fake_class_registry.FakeScriptObject
            ):
                constant = constant.real_obj
            assert isinstance(constant, torch._C.ScriptObject)

            custom_obj_name = f"{CUSTOM_OBJ_FILENAME_PREFIX}{custom_obj_idx}"

View File

@@ -1817,12 +1817,22 @@ def compile_fx(
"make sure torch.export() and torch.aot_compile() run on the same device."
)
inputs_ = fake_inputs # type: ignore[assignment]
return compile_fx(
model_,
inputs_,
inner_compile=functools.partial(inner_compile, cpp_wrapper=True),
decompositions=decompositions,
)
from torch._export.non_strict_utils import _fakify_script_objects
fake_mode = detect_fake_mode(inputs_)
with _fakify_script_objects(model_, inputs_, {}, fake_mode) as (
patched_mod,
fake_args,
_,
_,
_,
):
return compile_fx(
patched_mod,
fake_args,
inner_compile=functools.partial(inner_compile, cpp_wrapper=True),
decompositions=decompositions,
)
recursive_compile_fx = functools.partial(
compile_fx,

View File

@@ -348,7 +348,9 @@ class GraphLowering(torch.fx.Interpreter):
self.constants: dict[str, torch.Tensor] = (
const_module.constants if const_module else {}
)
self.torchbind_constants: dict[str, torch._C.ScriptObject] = {}
self.torchbind_constants: dict[
str, Union[torch._C.ScriptObject, FakeScriptObject]
] = {}
self.seen_subgraphs: dict[str, ir.Subgraph] = {}
self.constant_reprs: dict[str, str] = {}
self.removed_operations = OrderedSet[str]()
@@ -1231,9 +1233,9 @@
self.constant_reprs[target] = ""
return TorchBindObject(name=target, value=value)
elif isinstance(value, FakeScriptObject):
self.torchbind_constants[target] = value.real_obj
self.torchbind_constants[target] = value
self.constant_reprs[target] = ""
return TorchBindObject(name=target, value=value.real_obj)
return TorchBindObject(name=target, value=value)
assert isinstance(value, torch.Tensor)
if (

View File

@@ -98,6 +98,7 @@ from .virtualized import ops, OpsValue, V

if TYPE_CHECKING:
    from torch._library.fake_class_registry import FakeScriptObject
    from torch.fx.node import Node

    from .codegen.cuda.cuda_template import CUDATemplate
@@ -5154,7 +5155,9 @@ class ExternKernel(InputsKernel):
# strides of inputs and we need to determine accurately what the
# output stride will be.
example_args: list[
Union[torch.Tensor, torch._C.ScriptObject, torch.Generator]
Union[
torch.Tensor, torch._C.ScriptObject, FakeScriptObject, torch.Generator
]
] = []
# We need to retain the constant values of fake tensors that we originally
@@ -5171,7 +5174,7 @@
):
example_args.append(V.graph.torchbind_constants[x.get_name()])
elif isinstance(x, TorchBindObject):
example_args.append(x.get_real_obj())
example_args.append(x.get_value())
elif isinstance(x, torch._inductor.ir.GeneratorState):
device_index = x.device.index
assert x.device.type == "cuda" and device_index is not None
@@ -7798,8 +7801,6 @@ class NonTensorObj(IRNode):
@ir_dataclass
class TorchBindObject(NonTensorObj):
from torch._library.fake_class_registry import FakeScriptObject
name: str
value: Union[FakeScriptObject, torch.ScriptObject]

View File

@@ -62,6 +62,7 @@ if TYPE_CHECKING:
    from torch._inductor import metrics
    from torch._inductor.graph import GraphLowering
    from torch._library.fake_class_registry import FakeScriptObject

    from .compile_fx import _CompileFxKwargs
    from .triton_bundler import TritonKernelArtifacts
@@ -392,7 +393,7 @@ class CompiledFxGraph(OutputCode):
mutated_input_idxs: OrderedSet[int]
constants: Optional[dict[str, torch.Tensor]]
frozen_param_names: dict[str, str]
torchbind_constants: dict[str, torch._C.ScriptObject]
torchbind_constants: dict[str, torch._C.ScriptObject | FakeScriptObject]
output_strides: Optional[list[Optional[tuple[_StrideExprStr, ...]]]]
disabled_cudagraphs_reason: Optional[str]
metrics_deltas: metrics.CachedMetricsDeltas

View File

@@ -71,6 +71,13 @@ def _remove_effect_tokens_from_graph_helper(
            new_node = ep.graph.call_function(func, node.args[2:], node.kwargs)
            for k, v in node.meta.items():
                new_node.meta[k] = v
                if k == "unbacked_bindings":
                    # Remove the extra layer for effect token
                    old_bindings = new_node.meta[k]
                    new_bindings = {
                        k: path[1:] if path else path for k, path in old_bindings.items()
                    }
                    new_node.meta[k] = new_bindings

            node.replace_all_uses_with(new_node)

View File

@@ -61,6 +61,13 @@ def register_fake_operators():
        b = foo.add_tensor(a)
        return (a, b)

    @torch.library.register_fake("_TorchScriptTesting::takes_foo_tensor_return")
    def meta_takes_foo_tensor_return(foo, x):
        # This implementation deliberately creates unbacked symint for testing
        ctx = torch.library.get_ctx()
        fake_shape = [ctx.new_dynamic_size() for _ in range(2)]
        return torch.empty(fake_shape, dtype=torch.int, device="cpu")

    torch.ops._TorchScriptTesting.takes_foo_list_return.default.py_impl(
        torch._C.DispatchKey.Meta
    )(meta_takes_foo_list_return)