Hoist out auxiliary values in optional-typed arguments (#123613)

This fixes #123176, and partially addresses #121814 too. #123176 uses an
optional device arg while #121814 uses an optional list arg.

For optional arguments that have auxiliary info -- specifically, tuples
/ lists with their length parameter, and device types with their device
index -- we need to hoist out the extra argument. E.g. when passing a
device with ID 1, we want to emit

```
auto var_0 = cached_torch_device_type_cpu;
aoti_torch_foo(..., &var_0, 1);
```

instead of the (syntactically incorrect)

```
auto var_0 = cached_torch_device_type_cpu,1;
aoti_torch_foo(..., &var_0);
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/123613
Approved by: https://github.com/desertfire
This commit is contained in:
Jez Ng 2024-04-09 10:09:08 -07:00 committed by PyTorch MergeBot
parent 1970a802b3
commit 178ce1433c
3 changed files with 18 additions and 5 deletions

View File

@ -88,7 +88,6 @@ if config.abi_compatible:
"test_qlinear_cpu",
"test_qlinear_dequant_promotion_cpu",
"test_qlinear_relu_cpu",
"test_randn_with_dtype_and_device_cpu",
"test_scatter5_cpu",
"test_scatter6_cpu",
"test_tensor2_cpu",

View File

@ -2093,6 +2093,20 @@ RAIIAtenTensorHandle {output_arg}(
return "0" # nullptr is not available in C
if not isinstance(type_.getElementType(), torch.TensorType):
var_name = f"var_{next(self.arg_var_id)}"
if isinstance(
type_.getElementType(),
(torch.ListType, torch.TupleType, torch.DeviceObjType),
):
arg_str = self.val_to_arg_str(val)
if val is None:
return "{arg_str}, 0"
else:
# For datatypes with auxiliary info, we need to hoist out the extra arguments.
# NOTE: This only works if there is one additional argument, though it can easily be generalized.
main_value, aux = arg_str.rsplit(", ")
self.writeline(f"auto {var_name} = {main_value};")
return f"&{var_name}, {aux}"
else:
self.writeline(f"auto {var_name} = {self.val_to_arg_str(val)};")
return f"&{var_name}"
elif config.c_shim_version == "2":

View File

@ -141,7 +141,7 @@ inline std::array<bool, N> pointer_to_list(const int32_t* ptr) {
return result;
}
// utility functions to convert a pointer to a list of optional values
// Utility function to convert a pointer to an optional list of values
template <class T, class U>
inline c10::optional<c10::ArrayRef<T>> pointer_to_optional_list(
U** ptr,