mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[custom op] fix inductor cpp codegen when returning a list of single tensor (#147649)
For a custom op that returns a list containing a single tensor with an unbacked symint shape:
```python
@torch.library.custom_op(
"aoti_custom_ops::fn_ret_list_of_single_tensor", mutates_args={}
)
def fn_ret_list_of_single_tensor(x: torch.Tensor) -> list[torch.Tensor]:
s = x.sum().to(torch.int64)
return [torch.randn(s.item())]
@fn_ret_list_of_single_tensor.register_fake
def _(x):
ctx = torch._custom_op.impl.get_ctx()
i0 = ctx.new_dynamic_size()
return [torch.randn(i0)]
```
Before the fix, we have the following error:
```
/tmp/tmp5iikarn2/cci3ruqb7zdwtl457zo4itspq3sjnqiayhcshp5uaak7ktksckix/cggzqlwf4bmu6tjqodhoto3hhkhgharhwtvw2uxsasqrdipnazrv.cpp:456:26: error: type/value mismatch at argument 1 in template parameter list for ‘template<class _Tp, class ... _Types> constexpr const _Tp& std::get(const std::variant<_Types ...>&)’
456 | auto u0 = std::get<0>(buf1).size(0);
| ~~~~~~~~~~~^~~~~~
/tmp/tmp5iikarn2/cci3ruqb7zdwtl457zo4itspq3sjnqiayhcshp5uaak7ktksckix/cggzqlwf4bmu6tjqodhoto3hhkhgharhwtvw2uxsasqrdipnazrv.cpp:456:26: note: expected a type, got ‘0’
In file included from /data/users/yidi/pytorch/torch/include/c10/util/Exception.h:14,
from /data/users/yidi/pytorch/torch/include/c10/core/ScalarType.h:5,
from /data/users/yidi/pytorch/torch/include/ATen/AccumulateType.h:4,
from /data/users/yidi/pytorch/torch/include/ATen/native/Math.h:3,
from /data/users/yidi/pytorch/torch/include/ATen/cpu/vec/vec_base.h:31,
from /data/users/yidi/pytorch/torch/include/ATen/cpu/vec/vec512/vec512.h:8,
from /data/users/yidi/pytorch/torch/include/ATen/cpu/vec/vec.h:4,
from /data/users/yidi/pytorch/torch/include/ATen/cpu/vec/functional_base.h:6,
from /data/users/yidi/pytorch/torch/include/ATen/cpu/vec/functional.h:3,
from /tmp/tmp5iikarn2/3b/c3bi5gk6mslf6u4iaqafhxm64z6u65e3eain4xlary5blqnvv6xx.h:39,
from /tmp/tmp5iikarn2/cci3ruqb7zdwtl457zo4itspq3sjnqiayhcshp5uaak7ktksckix/cggzqlwf4bmu6tjqodhoto3hhkhgharhwtvw2uxsasqrdipnazrv.cpp:366:
/usr/include/c++/11/variant:1145:27: note: candidate: ‘template<class _Tp, class ... _Types> constexpr const _Tp&& std::get(const std::variant<_Types ...>&&)’
1145 | constexpr const _Tp&& get(const variant<_Types...>&& __v)
| ^~~
/usr/include/c++/11/variant:1145:27: note: template argument deduction/substitution failed:
/tmp/tmp5iikarn2/cci3ruqb7zdwtl457zo4itspq3sjnqiayhcshp5uaak7ktksckix/cggzqlwf4bmu6tjqodhoto3hhkhgharhwtvw2uxsasqrdipnazrv.cpp:456:26: error: type/value mismatch at argument 1 in template parameter list for ‘template<class _Tp, class ... _Types> constexpr const _Tp&& std::get(const std::variant<_Types ...>&&)’
456 | auto u0 = std::get<0>(buf1).size(0);
| ~~~~~~~~~~~^~~~~~
/tmp/tmp5iikarn2/cci3ruqb7zdwtl457zo4itspq3sjnqiayhcshp5uaak7ktksckix/cggzqlwf4bmu6tjqodhoto3hhkhgharhwtvw2uxsasqrdipnazrv.cpp:456:26: note: expected a type, got ‘0’
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/147649
Approved by: https://github.com/angelayi
ghstack dependencies: #147130
This commit is contained in:
parent
824474cb35
commit
adf0f4ffd2
|
|
@ -79,6 +79,34 @@ def fn_with_incorrect_optional_tensor_fake(
|
|||
return x + y + z
|
||||
|
||||
|
||||
@torch.library.custom_op(
|
||||
"aoti_custom_ops::fn_ret_list_of_single_tensor", mutates_args={}
|
||||
)
|
||||
def fn_ret_list_of_single_tensor(x: torch.Tensor) -> list[torch.Tensor]:
|
||||
s = x.sum().to(torch.int64)
|
||||
return [torch.randn(s.item())]
|
||||
|
||||
|
||||
@fn_ret_list_of_single_tensor.register_fake
|
||||
def _(x):
|
||||
ctx = torch._custom_op.impl.get_ctx()
|
||||
i0 = ctx.new_dynamic_size()
|
||||
return [torch.randn(i0)]
|
||||
|
||||
|
||||
@torch.library.custom_op("aoti_custom_ops::fn_ret_single_tensor", mutates_args={})
|
||||
def fn_ret_single_tensor(x: torch.Tensor) -> torch.Tensor:
|
||||
s = x.sum().to(torch.int64)
|
||||
return torch.randn(s.item())
|
||||
|
||||
|
||||
@fn_ret_single_tensor.register_fake
|
||||
def _(x):
|
||||
ctx = torch._custom_op.impl.get_ctx()
|
||||
i0 = ctx.new_dynamic_size()
|
||||
return torch.randn(i0)
|
||||
|
||||
|
||||
class AOTInductorTestsTemplate:
|
||||
def test_custom_op_add(self) -> None:
|
||||
class M(torch.nn.Module):
|
||||
|
|
@ -252,6 +280,24 @@ class AOTInductorTestsTemplate:
|
|||
|
||||
self.check_model(m, args)
|
||||
|
||||
def test_custom_op_return_list_of_single_tensor(self) -> None:
    """Exercise a custom op whose output is a list containing a single
    tensor of data-dependent shape, comparing compiled vs. eager."""

    class Model(torch.nn.Module):
        def forward(self, x):
            return torch.ops.aoti_custom_ops.fn_ret_list_of_single_tensor(x)[0] + 1

    model = Model().to(device=self.device)
    example_inputs = (torch.randn(3, 4),)
    self.check_model(model, example_inputs)
|
||||
|
||||
def test_custom_op_return_single_tensor(self) -> None:
    """Exercise a custom op that returns one tensor of data-dependent
    shape, comparing compiled vs. eager."""

    class Model(torch.nn.Module):
        def forward(self, x):
            return torch.ops.aoti_custom_ops.fn_ret_single_tensor(x) + 1

    model = Model().to(device=self.device)
    example_inputs = (torch.randn(3, 4),)
    self.check_model(model, example_inputs)
|
||||
|
||||
@unittest.skipIf(IS_FBCODE, "FbProxyExecutor doesn't have these error msgs")
|
||||
def test_incorrect_custom_op_schema(self):
|
||||
class M(torch.nn.Module):
|
||||
|
|
|
|||
|
|
@ -6479,7 +6479,16 @@ class FallbackKernel(ExternKernelAlloc):
|
|||
# individual output arguments are bound by
|
||||
# generate_c_shim_fallback_kernel
|
||||
if len(self.outputs) == 1:
|
||||
return go(self.outputs[0].get_name(), keypath)
|
||||
out = self.outputs[0]
|
||||
# When fallback kernel returns a list consisting of a single tensor,
|
||||
# the output is represented as a MultiOutput with non empty indices.
|
||||
# In this case, we strip the first key path away.
|
||||
return go(
|
||||
self.outputs[0].get_name(),
|
||||
keypath[1:]
|
||||
if isinstance(out, MultiOutput) and len(out.indices) != 0
|
||||
else keypath,
|
||||
)
|
||||
else:
|
||||
assert isinstance(keypath[0], pytree.SequenceKey)
|
||||
return go(self.outputs[keypath[0].idx].get_name(), keypath[1:])
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user