mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Fakify torchbind objects in compile_fx and add tests for SigridTransformsInstanceTorchBind (#149529)
Summary: We need to properly fakify torchbind objects, including the ones in graph module attributes, so the resgitered fake implementation works properly. - _fakify_script_objects in `compile_fx` - Allow fake torchbind objects in `torchbind_constants` Remove `node.meta["unbacked_bindings"]` for `aot_compile` in `compile_fx`. Otherwise `ShapeProp` will fail when trying to resolve the `unbacked_bindings` of `with_effect` tokens. Update `sigrid_transforms_test` to use the latest `torch._inductor.aot_compile` API. Add a test for `Fakify torchbind objects in compile_fx and add tests for SigridTransformsInstanceTorchBind` in `e2e_test`. Test Plan: ``` buck run //caffe2/torch/fb/sparsenn:sigrid_test -- -r test_transform_torch_bind buck run //sigmoid/inference/test:e2e_test_cpu -- -r SigridTransforms buck2 run mode/dev-nosan sigmoid/inference/ts_migration:pt2i_readiness_main -- --model_id 545017754 --test_suite ads_all --mode test_preproc ``` Differential Revision: D70013257 Pull Request resolved: https://github.com/pytorch/pytorch/pull/149529 Approved by: https://github.com/angelayi
This commit is contained in:
parent
19b763def1
commit
46dd226702
|
|
@ -518,6 +518,8 @@ TORCH_LIBRARY(_TorchScriptTesting, m) {
|
|||
"takes_foo_list_return(__torch__.torch.classes._TorchScriptTesting._Foo foo, Tensor x) -> Tensor[]");
|
||||
m.def(
|
||||
"takes_foo_tuple_return(__torch__.torch.classes._TorchScriptTesting._Foo foo, Tensor x) -> (Tensor, Tensor)");
|
||||
m.def(
|
||||
"takes_foo_tensor_return(__torch__.torch.classes._TorchScriptTesting._Foo foo, Tensor x) -> Tensor");
|
||||
|
||||
m.class_<FooGetterSetter>("_FooGetterSetter")
|
||||
.def(torch::init<int64_t, int64_t>())
|
||||
|
|
@ -701,6 +703,10 @@ std::tuple<at::Tensor, at::Tensor> takes_foo_tuple_return(
|
|||
return std::make_tuple(a, b);
|
||||
}
|
||||
|
||||
at::Tensor takes_foo_tensor_return(c10::intrusive_ptr<Foo> foo, at::Tensor x) {
|
||||
return at::ones({foo->x, foo->y}, at::device(at::kCPU).dtype(at::kInt));
|
||||
}
|
||||
|
||||
void queue_push(c10::intrusive_ptr<TensorQueue> tq, at::Tensor x) {
|
||||
tq->push(x);
|
||||
}
|
||||
|
|
@ -732,6 +738,7 @@ TORCH_LIBRARY_IMPL(_TorchScriptTesting, CPU, m) {
|
|||
m.impl("queue_push", queue_push);
|
||||
m.impl("queue_pop", queue_pop);
|
||||
m.impl("queue_size", queue_size);
|
||||
m.impl("takes_foo_tensor_return", takes_foo_tensor_return);
|
||||
}
|
||||
|
||||
TORCH_LIBRARY_IMPL(_TorchScriptTesting, Meta, m) {
|
||||
|
|
|
|||
|
|
@ -561,6 +561,40 @@ def forward(self, token, obj_attr, x):
|
|||
return (getitem_3, add_1)""", # noqa: B950
|
||||
)
|
||||
|
||||
@parametrize("pre_dispatch", [True, False])
|
||||
def test_custom_obj_unbacked_symint(self, pre_dispatch):
|
||||
class MyModule(torch.nn.Module):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.attr = torch.classes._TorchScriptTesting._Foo(2, 3)
|
||||
|
||||
def forward(self, x):
|
||||
a = torch.ops._TorchScriptTesting.takes_foo_tensor_return(self.attr, x)
|
||||
return a
|
||||
|
||||
input = torch.ones(2, 3)
|
||||
ep = self._test_export_same_as_eager(
|
||||
MyModule(), (input,), strict=False, pre_dispatch=pre_dispatch
|
||||
)
|
||||
gm = ep.module()
|
||||
foo_node = next(
|
||||
n
|
||||
for n in gm.graph.nodes
|
||||
if n.target == torch.ops._TorchScriptTesting.takes_foo_tensor_return.default
|
||||
)
|
||||
unbacked_bindings = foo_node.meta["unbacked_bindings"]
|
||||
self.assertEqual(len(unbacked_bindings), 2)
|
||||
u = next(iter(unbacked_bindings.keys()))
|
||||
path = unbacked_bindings[u]
|
||||
# the unbacked bindings should be CallMethodKey(name='size'), SequenceKey(idx=0)
|
||||
# it should not include the effect token in the path
|
||||
self.assertEqual(
|
||||
type(u).__name__, "Symbol"
|
||||
) # check binding is symbol, not expr
|
||||
self.assertEqual(len(path), 2)
|
||||
self.assertEqual(path[0].name, "size")
|
||||
self.assertEqual(path[1].idx, 0)
|
||||
|
||||
@parametrize("make_fx_tracing_mode", ["fake", "symbolic"])
|
||||
def test_make_fx_tensor_queue_methods(self, make_fx_tracing_mode):
|
||||
test = self
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ import contextlib
|
|||
import inspect
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, Callable, Optional, TYPE_CHECKING, Union
|
||||
|
||||
import torch
|
||||
|
|
@ -533,7 +534,7 @@ def _fakify_module_inputs(
|
|||
@contextlib.contextmanager
|
||||
def _fakify_script_objects(
|
||||
mod: torch.nn.Module,
|
||||
args: tuple[Any],
|
||||
args: Sequence[Any],
|
||||
kwargs: dict[Any, Any],
|
||||
fake_mode: torch._subclasses.fake_tensor.FakeTensorMode,
|
||||
):
|
||||
|
|
|
|||
|
|
@ -1761,6 +1761,10 @@ class AotCodeCompiler:
|
|||
for custom_obj_idx, (name, constant) in enumerate(
|
||||
graph.torchbind_constants.items()
|
||||
):
|
||||
if isinstance(
|
||||
constant, torch._library.fake_class_registry.FakeScriptObject
|
||||
):
|
||||
constant = constant.real_obj
|
||||
assert isinstance(constant, torch._C.ScriptObject)
|
||||
custom_obj_name = f"{CUSTOM_OBJ_FILENAME_PREFIX}{custom_obj_idx}"
|
||||
|
||||
|
|
|
|||
|
|
@ -1817,12 +1817,22 @@ def compile_fx(
|
|||
"make sure torch.export() and torch.aot_compile() run on the same device."
|
||||
)
|
||||
inputs_ = fake_inputs # type: ignore[assignment]
|
||||
return compile_fx(
|
||||
model_,
|
||||
inputs_,
|
||||
inner_compile=functools.partial(inner_compile, cpp_wrapper=True),
|
||||
decompositions=decompositions,
|
||||
)
|
||||
from torch._export.non_strict_utils import _fakify_script_objects
|
||||
|
||||
fake_mode = detect_fake_mode(inputs_)
|
||||
with _fakify_script_objects(model_, inputs_, {}, fake_mode) as (
|
||||
patched_mod,
|
||||
fake_args,
|
||||
_,
|
||||
_,
|
||||
_,
|
||||
):
|
||||
return compile_fx(
|
||||
patched_mod,
|
||||
fake_args,
|
||||
inner_compile=functools.partial(inner_compile, cpp_wrapper=True),
|
||||
decompositions=decompositions,
|
||||
)
|
||||
|
||||
recursive_compile_fx = functools.partial(
|
||||
compile_fx,
|
||||
|
|
|
|||
|
|
@ -348,7 +348,9 @@ class GraphLowering(torch.fx.Interpreter):
|
|||
self.constants: dict[str, torch.Tensor] = (
|
||||
const_module.constants if const_module else {}
|
||||
)
|
||||
self.torchbind_constants: dict[str, torch._C.ScriptObject] = {}
|
||||
self.torchbind_constants: dict[
|
||||
str, Union[torch._C.ScriptObject, FakeScriptObject]
|
||||
] = {}
|
||||
self.seen_subgraphs: dict[str, ir.Subgraph] = {}
|
||||
self.constant_reprs: dict[str, str] = {}
|
||||
self.removed_operations = OrderedSet[str]()
|
||||
|
|
@ -1231,9 +1233,9 @@ class GraphLowering(torch.fx.Interpreter):
|
|||
self.constant_reprs[target] = ""
|
||||
return TorchBindObject(name=target, value=value)
|
||||
elif isinstance(value, FakeScriptObject):
|
||||
self.torchbind_constants[target] = value.real_obj
|
||||
self.torchbind_constants[target] = value
|
||||
self.constant_reprs[target] = ""
|
||||
return TorchBindObject(name=target, value=value.real_obj)
|
||||
return TorchBindObject(name=target, value=value)
|
||||
|
||||
assert isinstance(value, torch.Tensor)
|
||||
if (
|
||||
|
|
|
|||
|
|
@ -98,6 +98,7 @@ from .virtualized import ops, OpsValue, V
|
|||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from torch._library.fake_class_registry import FakeScriptObject
|
||||
from torch.fx.node import Node
|
||||
|
||||
from .codegen.cuda.cuda_template import CUDATemplate
|
||||
|
|
@ -5154,7 +5155,9 @@ class ExternKernel(InputsKernel):
|
|||
# strides of inputs and we need to determine accurately what the
|
||||
# output stride will be.
|
||||
example_args: list[
|
||||
Union[torch.Tensor, torch._C.ScriptObject, torch.Generator]
|
||||
Union[
|
||||
torch.Tensor, torch._C.ScriptObject, FakeScriptObject, torch.Generator
|
||||
]
|
||||
] = []
|
||||
|
||||
# We need to retain the constant values of fake tensors that we originally
|
||||
|
|
@ -5171,7 +5174,7 @@ class ExternKernel(InputsKernel):
|
|||
):
|
||||
example_args.append(V.graph.torchbind_constants[x.get_name()])
|
||||
elif isinstance(x, TorchBindObject):
|
||||
example_args.append(x.get_real_obj())
|
||||
example_args.append(x.get_value())
|
||||
elif isinstance(x, torch._inductor.ir.GeneratorState):
|
||||
device_index = x.device.index
|
||||
assert x.device.type == "cuda" and device_index is not None
|
||||
|
|
@ -7798,8 +7801,6 @@ class NonTensorObj(IRNode):
|
|||
|
||||
@ir_dataclass
|
||||
class TorchBindObject(NonTensorObj):
|
||||
from torch._library.fake_class_registry import FakeScriptObject
|
||||
|
||||
name: str
|
||||
value: Union[FakeScriptObject, torch.ScriptObject]
|
||||
|
||||
|
|
|
|||
|
|
@ -62,6 +62,7 @@ if TYPE_CHECKING:
|
|||
|
||||
from torch._inductor import metrics
|
||||
from torch._inductor.graph import GraphLowering
|
||||
from torch._library.fake_class_registry import FakeScriptObject
|
||||
|
||||
from .compile_fx import _CompileFxKwargs
|
||||
from .triton_bundler import TritonKernelArtifacts
|
||||
|
|
@ -392,7 +393,7 @@ class CompiledFxGraph(OutputCode):
|
|||
mutated_input_idxs: OrderedSet[int]
|
||||
constants: Optional[dict[str, torch.Tensor]]
|
||||
frozen_param_names: dict[str, str]
|
||||
torchbind_constants: dict[str, torch._C.ScriptObject]
|
||||
torchbind_constants: dict[str, torch._C.ScriptObject | FakeScriptObject]
|
||||
output_strides: Optional[list[Optional[tuple[_StrideExprStr, ...]]]]
|
||||
disabled_cudagraphs_reason: Optional[str]
|
||||
metrics_deltas: metrics.CachedMetricsDeltas
|
||||
|
|
|
|||
|
|
@ -71,6 +71,13 @@ def _remove_effect_tokens_from_graph_helper(
|
|||
new_node = ep.graph.call_function(func, node.args[2:], node.kwargs)
|
||||
for k, v in node.meta.items():
|
||||
new_node.meta[k] = v
|
||||
if k == "unbacked_bindings":
|
||||
# Remove the extra layer for effect token
|
||||
old_bindings = new_node.meta[k]
|
||||
new_bindings = {
|
||||
k: path[1:] if path else path for k, path in old_bindings.items()
|
||||
}
|
||||
new_node.meta[k] = new_bindings
|
||||
|
||||
node.replace_all_uses_with(new_node)
|
||||
|
||||
|
|
|
|||
|
|
@ -61,6 +61,13 @@ def register_fake_operators():
|
|||
b = foo.add_tensor(a)
|
||||
return (a, b)
|
||||
|
||||
@torch.library.register_fake("_TorchScriptTesting::takes_foo_tensor_return")
|
||||
def meta_takes_foo_tensor_return(foo, x):
|
||||
# This implementation deliberately creates unbacked symint for testing
|
||||
ctx = torch.library.get_ctx()
|
||||
fake_shape = [ctx.new_dynamic_size() for _ in range(2)]
|
||||
return torch.empty(fake_shape, dtype=torch.int, device="cpu")
|
||||
|
||||
torch.ops._TorchScriptTesting.takes_foo_list_return.default.py_impl(
|
||||
torch._C.DispatchKey.Meta
|
||||
)(meta_takes_foo_list_return)
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user