Fix with effect lowering for list return type (#149510)

Summary: - For `torch.ops.higher_order.with_effects`'s lowering, we should not extract the items out of an list (i.e. `*result` vs `result`). The `get_attr` nodes consider the result to be in the list format.

Test Plan:
```
buck run fbcode//mode/dev-nosan //caffe2/test/inductor:torchbind -- -r test_torchbind_aot_compile

buck run fbcode//mode/dev-nosan //caffe2/test/inductor:torchbind -- -r list_return

buck run //caffe2/torch/fb/sparsenn:sigrid_test -- -r test_transform_torch_bind # tested together with D70013257

buck run fbcode//mode/dev-nosan //caffe2/test:test_export  -- -r test_custom_obj
```

Reviewed By: angelayi

Differential Revision: D71346024

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149510
Approved by: https://github.com/zou3519
This commit is contained in:
Shangdi Yu 2025-03-19 19:35:05 +00:00 committed by PyTorch MergeBot
parent 842a072fd3
commit 05fee772e5
3 changed files with 33 additions and 1 deletions

View File

@ -286,6 +286,30 @@ class TestTorchbind(TestCase):
# TODO: add accuracy test after we support loading and running compiled models with # TODO: add accuracy test after we support loading and running compiled models with
# torchbind objects. # torchbind objects.
def test_torchbind_list_return_aot_compile(self):
class M(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.attr = torch.classes._TorchScriptTesting._Foo(10, 20)
def forward(self, x):
a = torch.ops._TorchScriptTesting.takes_foo_list_return(self.attr, x)
y = a[0] + a[1] + a[2]
b = torch.ops._TorchScriptTesting.takes_foo(self.attr, y)
return x + b
m = M()
inputs = (torch.ones(2, 3),)
# We can't directly torch.compile because dynamo doesn't trace ScriptObjects yet
with enable_torchbind_tracing():
ep = torch.export.export(m, inputs, strict=False)
aot_compile(ep.module(), inputs, options={"aot_inductor.package": True})
# TODO: add accuracy test after we support loading and running compiled models with
# torchbind objects.
if __name__ == "__main__": if __name__ == "__main__":
run_tests() run_tests()

View File

@ -146,6 +146,12 @@ def with_effects_dense(
) -> tuple[torch.Tensor, ...]: ) -> tuple[torch.Tensor, ...]:
out = op(*args, **kwargs) out = op(*args, **kwargs)
new_token = new_token_tensor() new_token = new_token_tensor()
# [NOTE: with_effects return type]
# Note that we should only do *out for tuple type, but not list type.
# This is to match the schema of the op.
# For tuple output, the length of schema output is the same as the length of out.
# For list output, the length of schema output is 1 (e.g. Tensor[]) regardless of the
# length of the list.
if isinstance(out, tuple): if isinstance(out, tuple):
return (new_token, *out) return (new_token, *out)
return (new_token, out) return (new_token, out)

View File

@ -6954,7 +6954,9 @@ def with_effects(token, op, *args, **kwargs):
return (effectful_kernel,) return (effectful_kernel,)
result = pytree.tree_map_only(ir.MultiOutput, TensorBox.create, result) result = pytree.tree_map_only(ir.MultiOutput, TensorBox.create, result)
if not isinstance(result, (list, tuple)): # See [NOTE: with_effects return type]
# Only return `result` if it is a tuple, not list.
if not isinstance(result, tuple):
return (effectful_kernel, result) return (effectful_kernel, result)
else: else:
return (effectful_kernel, *result) return (effectful_kernel, *result)