[custom op] fix inductor cpp codegen when returning a list of single tensor (#147649)

For a custom op that returns a list containing a single tensor with an unbacked symint shape:
```python

@torch.library.custom_op(
    "aoti_custom_ops::fn_ret_list_of_single_tensor", mutates_args={}
)
def fn_ret_list_of_single_tensor(x: torch.Tensor) -> list[torch.Tensor]:
    s = x.sum().to(torch.int64)
    return [torch.randn(s.item())]

@fn_ret_list_of_single_tensor.register_fake
def _(x):
    ctx = torch._custom_op.impl.get_ctx()
    i0 = ctx.new_dynamic_size()
    return [torch.randn(i0)]
```

Before the fix, we have the following error:
```
/tmp/tmp5iikarn2/cci3ruqb7zdwtl457zo4itspq3sjnqiayhcshp5uaak7ktksckix/cggzqlwf4bmu6tjqodhoto3hhkhgharhwtvw2uxsasqrdipnazrv.cpp:456:26: error: type/value mismatch at argument 1 in template parameter list for ‘template<class _Tp, class ... _Types> constexpr const _Tp& std::get(const std::variant<_Types ...>&)’
  456 |     auto u0 = std::get<0>(buf1).size(0);
      |               ~~~~~~~~~~~^~~~~~
/tmp/tmp5iikarn2/cci3ruqb7zdwtl457zo4itspq3sjnqiayhcshp5uaak7ktksckix/cggzqlwf4bmu6tjqodhoto3hhkhgharhwtvw2uxsasqrdipnazrv.cpp:456:26: note:   expected a type, got ‘0’
In file included from /data/users/yidi/pytorch/torch/include/c10/util/Exception.h:14,
                 from /data/users/yidi/pytorch/torch/include/c10/core/ScalarType.h:5,
                 from /data/users/yidi/pytorch/torch/include/ATen/AccumulateType.h:4,
                 from /data/users/yidi/pytorch/torch/include/ATen/native/Math.h:3,
                 from /data/users/yidi/pytorch/torch/include/ATen/cpu/vec/vec_base.h:31,
                 from /data/users/yidi/pytorch/torch/include/ATen/cpu/vec/vec512/vec512.h:8,
                 from /data/users/yidi/pytorch/torch/include/ATen/cpu/vec/vec.h:4,
                 from /data/users/yidi/pytorch/torch/include/ATen/cpu/vec/functional_base.h:6,
                 from /data/users/yidi/pytorch/torch/include/ATen/cpu/vec/functional.h:3,
                 from /tmp/tmp5iikarn2/3b/c3bi5gk6mslf6u4iaqafhxm64z6u65e3eain4xlary5blqnvv6xx.h:39,
                 from /tmp/tmp5iikarn2/cci3ruqb7zdwtl457zo4itspq3sjnqiayhcshp5uaak7ktksckix/cggzqlwf4bmu6tjqodhoto3hhkhgharhwtvw2uxsasqrdipnazrv.cpp:366:
/usr/include/c++/11/variant:1145:27: note: candidate: ‘template<class _Tp, class ... _Types> constexpr const _Tp&& std::get(const std::variant<_Types ...>&&)’
 1145 |     constexpr const _Tp&& get(const variant<_Types...>&& __v)
      |                           ^~~
/usr/include/c++/11/variant:1145:27: note:   template argument deduction/substitution failed:
/tmp/tmp5iikarn2/cci3ruqb7zdwtl457zo4itspq3sjnqiayhcshp5uaak7ktksckix/cggzqlwf4bmu6tjqodhoto3hhkhgharhwtvw2uxsasqrdipnazrv.cpp:456:26: error: type/value mismatch at argument 1 in template parameter list for ‘template<class _Tp, class ... _Types> constexpr const _Tp&& std::get(const std::variant<_Types ...>&&)’
  456 |     auto u0 = std::get<0>(buf1).size(0);
      |               ~~~~~~~~~~~^~~~~~
/tmp/tmp5iikarn2/cci3ruqb7zdwtl457zo4itspq3sjnqiayhcshp5uaak7ktksckix/cggzqlwf4bmu6tjqodhoto3hhkhgharhwtvw2uxsasqrdipnazrv.cpp:456:26: note:   expected a type, got ‘0’
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147649
Approved by: https://github.com/angelayi
ghstack dependencies: #147130
This commit is contained in:
Yidi Wu 2025-02-24 13:34:32 -08:00 committed by PyTorch MergeBot
parent 824474cb35
commit adf0f4ffd2
2 changed files with 56 additions and 1 deletions

View File

@ -79,6 +79,34 @@ def fn_with_incorrect_optional_tensor_fake(
return x + y + z
@torch.library.custom_op(
"aoti_custom_ops::fn_ret_list_of_single_tensor", mutates_args={}
)
def fn_ret_list_of_single_tensor(x: torch.Tensor) -> list[torch.Tensor]:
s = x.sum().to(torch.int64)
return [torch.randn(s.item())]
@fn_ret_list_of_single_tensor.register_fake
def _(x):
ctx = torch._custom_op.impl.get_ctx()
i0 = ctx.new_dynamic_size()
return [torch.randn(i0)]
@torch.library.custom_op("aoti_custom_ops::fn_ret_single_tensor", mutates_args={})
def fn_ret_single_tensor(x: torch.Tensor) -> torch.Tensor:
s = x.sum().to(torch.int64)
return torch.randn(s.item())
@fn_ret_single_tensor.register_fake
def _(x):
ctx = torch._custom_op.impl.get_ctx()
i0 = ctx.new_dynamic_size()
return torch.randn(i0)
class AOTInductorTestsTemplate:
def test_custom_op_add(self) -> None:
class M(torch.nn.Module):
@ -252,6 +280,24 @@ class AOTInductorTestsTemplate:
self.check_model(m, args)
def test_custom_op_return_list_of_single_tensor(self) -> None:
class Model(torch.nn.Module):
def forward(self, x):
return torch.ops.aoti_custom_ops.fn_ret_list_of_single_tensor(x)[0] + 1
m = Model().to(device=self.device)
args = (torch.randn(3, 4),)
self.check_model(m, args)
def test_custom_op_return_single_tensor(self) -> None:
class Model(torch.nn.Module):
def forward(self, x):
return torch.ops.aoti_custom_ops.fn_ret_single_tensor(x) + 1
m = Model().to(device=self.device)
args = (torch.randn(3, 4),)
self.check_model(m, args)
@unittest.skipIf(IS_FBCODE, "FbProxyExecutor doesn't have these error msgs")
def test_incorrect_custom_op_schema(self):
class M(torch.nn.Module):

View File

@ -6479,7 +6479,16 @@ class FallbackKernel(ExternKernelAlloc):
# individual output arguments are bound by
# generate_c_shim_fallback_kernel
if len(self.outputs) == 1:
return go(self.outputs[0].get_name(), keypath)
out = self.outputs[0]
# When fallback kernel returns a list consisting of a single tensor,
# the output is represented as a MultiOutput with non empty indices.
# In this case, we strip the first key path away.
return go(
self.outputs[0].get_name(),
keypath[1:]
if isinstance(out, MultiOutput) and len(out.indices) != 0
else keypath,
)
else:
assert isinstance(keypath[0], pytree.SequenceKey)
return go(self.outputs[keypath[0].idx].get_name(), keypath[1:])