mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Hoist out auxiliary values in optional-typed arguments (#123613)
This fixes #123176, and partially addresses #121814 too. #123176 uses an optional device arg while #121814 uses an optional list arg. For optional arguments that have auxiliary info -- specifically, tuples / lists with their length parameter, and device types with their device index -- we need to hoist out the extra argument. E.g. when passing a device with ID 1, we want to emit ``` auto var_0 = cached_torch_device_type_cpu; aoti_torch_foo(..., &var_0, 1); ``` instead of the (syntactically incorrect) ``` auto var_0 = cached_torch_device_type_cpu,1; aoti_torch_foo(..., &var_0); ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/123613 Approved by: https://github.com/desertfire
This commit is contained in:
parent
1970a802b3
commit
178ce1433c
|
|
@ -88,7 +88,6 @@ if config.abi_compatible:
|
|||
"test_qlinear_cpu",
|
||||
"test_qlinear_dequant_promotion_cpu",
|
||||
"test_qlinear_relu_cpu",
|
||||
"test_randn_with_dtype_and_device_cpu",
|
||||
"test_scatter5_cpu",
|
||||
"test_scatter6_cpu",
|
||||
"test_tensor2_cpu",
|
||||
|
|
|
|||
|
|
@ -2093,6 +2093,20 @@ RAIIAtenTensorHandle {output_arg}(
|
|||
return "0" # nullptr is not available in C
|
||||
if not isinstance(type_.getElementType(), torch.TensorType):
|
||||
var_name = f"var_{next(self.arg_var_id)}"
|
||||
if isinstance(
|
||||
type_.getElementType(),
|
||||
(torch.ListType, torch.TupleType, torch.DeviceObjType),
|
||||
):
|
||||
arg_str = self.val_to_arg_str(val)
|
||||
if val is None:
|
||||
return "{arg_str}, 0"
|
||||
else:
|
||||
# For datatypes with auxiliary info, we need to hoist out the extra arguments.
|
||||
# NOTE: This only works if there is one additional argument, though it can easily be generalized.
|
||||
main_value, aux = arg_str.rsplit(", ")
|
||||
self.writeline(f"auto {var_name} = {main_value};")
|
||||
return f"&{var_name}, {aux}"
|
||||
else:
|
||||
self.writeline(f"auto {var_name} = {self.val_to_arg_str(val)};")
|
||||
return f"&{var_name}"
|
||||
elif config.c_shim_version == "2":
|
||||
|
|
|
|||
|
|
@ -141,7 +141,7 @@ inline std::array<bool, N> pointer_to_list(const int32_t* ptr) {
|
|||
return result;
|
||||
}
|
||||
|
||||
// utility functions to convert a pointer to a list of optional values
|
||||
// Utility function to convert a pointer to an optional list of values
|
||||
template <class T, class U>
|
||||
inline c10::optional<c10::ArrayRef<T>> pointer_to_optional_list(
|
||||
U** ptr,
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user