[AOTI] Support InplaceBernoulliFallback in the ABI-compatible codegen (#126183)
Summary: Update the torchgen rule for inplace ops like bernoulli_, and update InplaceBernoulliFallback to codegen in the ABI-compatible mode.

Fixes https://github.com/pytorch/pytorch/issues/121809

Pull Request resolved: https://github.com/pytorch/pytorch/pull/126183
Approved by: https://github.com/angelayi
ghstack dependencies: #126181, #126182
This commit is contained in:
parent 5792bc3c3e
commit 0332b5812e
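For context, here is a hypothetical repro sketch of the scenario this change targets; the config flags and their availability are assumptions for illustration and are not taken from the PR. On CPU, aten.bernoulli_ is lowered to InplaceBernoulliFallback, and with the cpp wrapper in ABI-compatible mode that fallback must be emitted as a call into the AOTI C shim.

import torch
import torch._inductor.config as inductor_config

# Hypothetical repro sketch (assumed flags, not from the PR): exercise the
# aten.bernoulli_ fallback through the ABI-compatible cpp wrapper on CPU.
inductor_config.cpp_wrapper = True
inductor_config.abi_compatible = True

def f(p):
    x = torch.empty_like(p)
    x.bernoulli_(p)  # lowered to InplaceBernoulliFallback on CPU
    return x

compiled = torch.compile(f)
print(compiled(torch.full((8,), 0.5)))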
@@ -71,7 +71,6 @@ test_failures_cpp_wrapper = {
 if config.abi_compatible:
     xfail_list = [
-        "test_bernoulli1_cpu",  # cpp fallback op naming issue
         "test_conv2d_binary_inplace_fusion_failed_cpu",
         "test_conv2d_binary_inplace_fusion_pass_cpu",
         "test_dynamic_qlinear_cpu",
@@ -97,7 +97,6 @@ if TEST_WITH_ROCM:
 if config.abi_compatible:
     xfail_list = [
-        "test_bernoulli1_cuda",  # cpp fallback op naming issue
         "test_profiler_mark_wrapper_call_cuda",
         "test_scaled_dot_product_attention_cuda_dynamic_shapes",
     ]
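With the bernoulli entries removed from both xfail lists, the test_bernoulli1 tests are expected to pass under the ABI-compatible cpp wrapper. An illustrative sketch (not the actual test body) of what such a test checks, since the output is random and can only be validated by its values:

import torch

# Illustrative sketch of a "test_bernoulli1"-style check: bernoulli_ inside a
# compiled function must produce only 0/1 values; exact outputs are random.
def fn(p):
    out = torch.empty_like(p)
    out.bernoulli_(p)
    return out

out = torch.compile(fn)(torch.full((64,), 0.5))
assert ((out == 0) | (out == 1)).all()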
@@ -4784,9 +4784,17 @@ class InplaceBernoulliFallback(ExternKernel):
     def codegen(self, wrapper):
         (x,) = (t.codegen_reference() for t in self.inputs)
-        wrapper.writeline(
-            f"{self.get_kernel_name()}({x}, {', '.join(map(repr, self.constant_args))}){wrapper.ending}"
-        )
+        if V.graph.cpp_wrapper and config.abi_compatible:
+            # Inductor doesn't really support aten Generator, so the Generator kwarg is always NULL here,
+            # which needs to be explicitly generated for cpp wrapper
+            wrapper.writeline(
+                f"{self.get_kernel_name()}({x}, {', '.join(map(repr, self.constant_args))}, NULL){wrapper.ending}"
+            )
+        else:
+            wrapper.writeline(
+                f"{self.get_kernel_name()}({x}, {', '.join(map(repr, self.constant_args))}){wrapper.ending}"
+            )

     def should_allocate(self):
         return False
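A standalone sketch of what the two branches above emit; the kernel name, buffer name, and probability are assumed example values (the real ones come from get_kernel_name() and the graph):

# Sketch of the two codegen branches with assumed example values.
kernel_name = "aoti_torch_cpu_bernoulli__float"  # assumed ABI-compatible shim name
x = "buf0"                                       # assumed codegen reference for the input
constant_args = (0.5,)
ending = ";"

abi_compatible_cpp_wrapper = True
if abi_compatible_cpp_wrapper:
    # the C shim takes an explicit (always-NULL) generator argument
    line = f"{kernel_name}({x}, {', '.join(map(repr, constant_args))}, NULL){ending}"
else:
    line = f"{kernel_name}({x}, {', '.join(map(repr, constant_args))}){ending}"

print(line)  # aoti_torch_cpu_bernoulli__float(buf0, 0.5, NULL);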
@@ -4797,20 +4805,19 @@ class InplaceBernoulliFallback(ExternKernel):
     def get_unbacked_symbol_defs(self) -> Set[sympy.Symbol]:
         return set()

-    def __init__(self, x, *constant_args):
+    def __init__(self, op_overload, x, *constant_args):
         super().__init__(
             None,
             NoneLayout(x.get_device()),  # type: ignore[arg-type]
             self.unwrap_storage([x]),
             constant_args,
+            op_overload=op_overload,
         )
         self.name = V.graph.register_buffer(self)
         self.python_kernel_name = "aten.bernoulli_"
-        self.cpp_kernel_name = (
-            "aoti_torch_bernoulli_"
-            if config.abi_compatible
-            else "at::native::bernoulli_"
-        )
+        if not config.abi_compatible:
+            # TODO: this should be simplified once we switch to ABI-compatible only
+            self.cpp_kernel_name = "at::native::bernoulli_"
         mark_node_as_mutating(self, x)
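The constructor no longer hard-codes an ABI-compatible kernel name; instead it forwards the OpOverload to ExternKernel, which (on my reading of the change) lets the kernel name be derived per overload downstream. A small sketch of what the new op_overload argument carries:

import torch

# Sketch: the OpOverload records the exact schema/overload, which is enough to
# pick a per-overload C shim in ABI-compatible mode instead of the old
# hand-written "aoti_torch_bernoulli_" name.
op = torch.ops.aten.bernoulli_.float
print(str(op))     # aten.bernoulli_.float
print(op._schema)  # full schema, including the Generator kwarg that codegen passes as NULL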
@@ -1788,7 +1788,12 @@ def bernoulli_(x, *args):
         "cpu"
     ), "this should be handled in decomps unless config.fallback_random or the device is CPU"
     x.realize()
-    ir.InplaceBernoulliFallback(x, *args)
+    op_overload = (
+        aten.bernoulli_.float
+        if len(args) == 0 or isinstance(args[0], float)
+        else aten.bernoulli_.Tensor
+    )
+    ir.InplaceBernoulliFallback(op_overload, x, *args)
     return x
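The overload selection mirrors the two aten overloads of bernoulli_. For example, in eager terms (shown only to name the overloads involved):

import torch

# The two overloads the lowering now distinguishes when building the fallback node:
x = torch.empty(4)
x.bernoulli_(0.3)            # p is a float       -> aten.bernoulli_.float
x.bernoulli_()               # no p (default 0.5) -> aten.bernoulli_.float
x.bernoulli_(torch.rand(4))  # p is a Tensor      -> aten.bernoulli_.Tensor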
@@ -47,6 +47,8 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_avg_pool2d(AtenTensorHandle self
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_avg_pool2d_backward(AtenTensorHandle grad_output, AtenTensorHandle self, const int64_t* kernel_size, int64_t kernel_size_len_, const int64_t* stride, int64_t stride_len_, const int64_t* padding, int64_t padding_len_, int32_t ceil_mode, int32_t count_include_pad, int64_t* divisor_override, AtenTensorHandle* ret0);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_avg_pool3d(AtenTensorHandle self, const int64_t* kernel_size, int64_t kernel_size_len_, const int64_t* stride, int64_t stride_len_, const int64_t* padding, int64_t padding_len_, int32_t ceil_mode, int32_t count_include_pad, int64_t* divisor_override, AtenTensorHandle* ret0);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_avg_pool3d_backward(AtenTensorHandle grad_output, AtenTensorHandle self, const int64_t* kernel_size, int64_t kernel_size_len_, const int64_t* stride, int64_t stride_len_, const int64_t* padding, int64_t padding_len_, int32_t ceil_mode, int32_t count_include_pad, int64_t* divisor_override, AtenTensorHandle* ret0);
+AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_bernoulli__Tensor(AtenTensorHandle self, AtenTensorHandle p, AtenGeneratorHandle* generator);
+AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_bernoulli__float(AtenTensorHandle self, double p, AtenGeneratorHandle* generator);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_bmm_out(AtenTensorHandle out, AtenTensorHandle self, AtenTensorHandle mat2);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_bucketize_Tensor(AtenTensorHandle self, AtenTensorHandle boundaries, int32_t out_int32, int32_t right, AtenTensorHandle* ret0);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_cat(const AtenTensorHandle* tensors, int64_t tensors_len_, int64_t dim, AtenTensorHandle* ret0);
@@ -105,8 +107,8 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_randperm(int64_t n, int32_t* dty
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_repeat_interleave_Tensor(AtenTensorHandle repeats, int64_t* output_size, AtenTensorHandle* ret0);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_replication_pad1d_backward(AtenTensorHandle grad_output, AtenTensorHandle self, const int64_t* padding, int64_t padding_len_, AtenTensorHandle* ret0);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_replication_pad2d_backward(AtenTensorHandle grad_output, AtenTensorHandle self, const int64_t* padding, int64_t padding_len_, AtenTensorHandle* ret0);
-AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_resize_(AtenTensorHandle self, const int64_t* size, int64_t size_len_, int32_t* memory_format, AtenTensorHandle* ret0);
-AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_resize_as_(AtenTensorHandle self, AtenTensorHandle the_template, int32_t* memory_format, AtenTensorHandle* ret0);
+AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_resize_(AtenTensorHandle self, const int64_t* size, int64_t size_len_, int32_t* memory_format);
+AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_resize_as_(AtenTensorHandle self, AtenTensorHandle the_template, int32_t* memory_format);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_scatter_src_out(AtenTensorHandle out, AtenTensorHandle self, int64_t dim, AtenTensorHandle index, AtenTensorHandle src);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_scatter_value_out(AtenTensorHandle out, AtenTensorHandle self, int64_t dim, AtenTensorHandle index, double value);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_scatter_reduce_two_out(AtenTensorHandle out, AtenTensorHandle self, int64_t dim, AtenTensorHandle index, AtenTensorHandle src, const char* reduce, int32_t include_self);
@@ -55,6 +55,8 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_avg_pool2d(AtenTensorHandle sel
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_avg_pool2d_backward(AtenTensorHandle grad_output, AtenTensorHandle self, const int64_t* kernel_size, int64_t kernel_size_len_, const int64_t* stride, int64_t stride_len_, const int64_t* padding, int64_t padding_len_, int32_t ceil_mode, int32_t count_include_pad, int64_t* divisor_override, AtenTensorHandle* ret0);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_avg_pool3d(AtenTensorHandle self, const int64_t* kernel_size, int64_t kernel_size_len_, const int64_t* stride, int64_t stride_len_, const int64_t* padding, int64_t padding_len_, int32_t ceil_mode, int32_t count_include_pad, int64_t* divisor_override, AtenTensorHandle* ret0);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_avg_pool3d_backward(AtenTensorHandle grad_output, AtenTensorHandle self, const int64_t* kernel_size, int64_t kernel_size_len_, const int64_t* stride, int64_t stride_len_, const int64_t* padding, int64_t padding_len_, int32_t ceil_mode, int32_t count_include_pad, int64_t* divisor_override, AtenTensorHandle* ret0);
+AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_bernoulli__Tensor(AtenTensorHandle self, AtenTensorHandle p, AtenGeneratorHandle* generator);
+AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_bernoulli__float(AtenTensorHandle self, double p, AtenGeneratorHandle* generator);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_bmm_out(AtenTensorHandle out, AtenTensorHandle self, AtenTensorHandle mat2);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_bucketize_Tensor(AtenTensorHandle self, AtenTensorHandle boundaries, int32_t out_int32, int32_t right, AtenTensorHandle* ret0);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_cat(const AtenTensorHandle* tensors, int64_t tensors_len_, int64_t dim, AtenTensorHandle* ret0);
@@ -112,8 +114,8 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_randperm(int64_t n, int32_t* dt
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_repeat_interleave_Tensor(AtenTensorHandle repeats, int64_t* output_size, AtenTensorHandle* ret0);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_replication_pad1d_backward(AtenTensorHandle grad_output, AtenTensorHandle self, const int64_t* padding, int64_t padding_len_, AtenTensorHandle* ret0);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_replication_pad2d_backward(AtenTensorHandle grad_output, AtenTensorHandle self, const int64_t* padding, int64_t padding_len_, AtenTensorHandle* ret0);
-AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_resize_(AtenTensorHandle self, const int64_t* size, int64_t size_len_, int32_t* memory_format, AtenTensorHandle* ret0);
-AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_resize_as_(AtenTensorHandle self, AtenTensorHandle the_template, int32_t* memory_format, AtenTensorHandle* ret0);
+AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_resize_(AtenTensorHandle self, const int64_t* size, int64_t size_len_, int32_t* memory_format);
+AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_resize_as_(AtenTensorHandle self, AtenTensorHandle the_template, int32_t* memory_format);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_scatter_src_out(AtenTensorHandle out, AtenTensorHandle self, int64_t dim, AtenTensorHandle index, AtenTensorHandle src);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_scatter_value_out(AtenTensorHandle out, AtenTensorHandle self, int64_t dim, AtenTensorHandle index, double value);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_scatter_reduce_two_out(AtenTensorHandle out, AtenTensorHandle self, int64_t dim, AtenTensorHandle index, AtenTensorHandle src, const char* reduce, int32_t include_self);
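The new declarations follow the shim naming pattern visible throughout these headers. A small sketch of the convention as I read it from the declarations above (inferred, not an official API), mapping an aten overload to the shim symbol:

# Assumed naming convention, inferred from the declarations above:
# aoti_torch_<device>_<op-name-without-namespace>_<overload>
def shim_symbol(device: str, qualified_op: str, overload: str) -> str:
    op_name = qualified_op.split(".")[-1]  # "aten.bernoulli_" -> "bernoulli_"
    return f"aoti_torch_{device}_{op_name}_{overload}"

print(shim_symbol("cpu", "aten.bernoulli_", "float"))    # aoti_torch_cpu_bernoulli__float
print(shim_symbol("cuda", "aten.bernoulli_", "Tensor"))  # aoti_torch_cuda_bernoulli__Tensor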
@@ -25,6 +25,8 @@ inductor_fallback_ops = {
     "aten.avg_pool2d.default",
     "aten.avg_pool3d_backward.default",
     "aten.avg_pool3d.default",
+    "aten.bernoulli_.float",
+    "aten.bernoulli_.Tensor",
     "aten.bmm.out",
     "aten.bucketize.Tensor",
     "aten.cat.default",
@@ -249,18 +249,18 @@ def gen_declaration_and_definition(
         return declaration_definition_cache[(func_name, device, backend_call)]

     if schema.is_out_fn():
-        # out_variant has out arguments in the front, and it's ok to ignore return value
+        # out_variant has out arguments in the front, and it's ok to ignore return values
         # because C shim functions only return AOTITorchError
-        # Somehow at::native out-variant functions have out arguments in the back
         args, callsite_exprs = gen_arguments(
-            [*schema.arguments.flat_non_out, *schema.arguments.out]
-            if "at::native" in backend_call
-            else [*schema.arguments.out, *schema.arguments.flat_non_out],
+            [*schema.arguments.out, *schema.arguments.flat_non_out]
        )
         ret_assignments: List[str] = []
     else:
         args, callsite_exprs = gen_arguments(schema.arguments.flat_all)
-        ret_declarations, ret_assignments = gen_returns(schema)
+        # ignore return values for inplace ops
+        ret_declarations, ret_assignments = (
+            ([], []) if schema.name.name.inplace else gen_returns(schema)
+        )
         args.extend(ret_declarations)

     declaration = f"AOTITorchError aoti_torch_{device}_{func_name}({', '.join(args)})"
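A minimal sketch of the effect of the new return-handling rule on generated declarations; this is a simplification under stated assumptions (the real code works on torchgen schema objects, and the example argument lists are taken from the header hunks above):

# Simplified sketch: in-place ops mutate and return `self`, so the generated C shim
# drops the trailing out-parameter that functional/out ops still get.
def shim_args(base_args, is_inplace):
    ret_declarations = [] if is_inplace else ["AtenTensorHandle* ret0"]
    return base_args + ret_declarations

print(shim_args(["AtenTensorHandle self", "double p", "AtenGeneratorHandle* generator"], True))
# ['AtenTensorHandle self', 'double p', 'AtenGeneratorHandle* generator']   (bernoulli__float)
print(shim_args(["AtenTensorHandle self", "AtenTensorHandle boundaries"], False))
# ['AtenTensorHandle self', 'AtenTensorHandle boundaries', 'AtenTensorHandle* ret0']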