[AOTI] Support InplaceBernoulliFallback in the ABI-compatible codegen (#126183)
Summary: Update the torchgen rule for inplace ops like bernoulli_, and update InplaceBernoulliFallback to codegen in the ABI-compatible mode.

Fixes https://github.com/pytorch/pytorch/issues/121809

Pull Request resolved: https://github.com/pytorch/pytorch/pull/126183
Approved by: https://github.com/angelayi
ghstack dependencies: #126181, #126182
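For context, a minimal sketch of the op pattern this PR unblocks (assumptions: the linked issue exercises the cpp wrapper with abi_compatible=True; the function and inputs below are illustrative, not from the PR):

import torch

# The op pattern at issue: an in-place bernoulli_ that Inductor lowers via
# the InplaceBernoulliFallback extern kernel (always the case on CPU).
def fn(x):
    return x.bernoulli_(0.5)

# Hypothetical invocation; issue #121809 hits this path when the model is
# compiled with cpp_wrapper=True and abi_compatible=True.
out = torch.compile(fn)(torch.rand(8))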
This commit is contained in:
parent
5792bc3c3e
commit
0332b5812e
@@ -71,7 +71,6 @@ test_failures_cpp_wrapper = {
 if config.abi_compatible:
     xfail_list = [
-        "test_bernoulli1_cpu",  # cpp fallback op naming issue
         "test_conv2d_binary_inplace_fusion_failed_cpu",
         "test_conv2d_binary_inplace_fusion_pass_cpu",
         "test_dynamic_qlinear_cpu",
@@ -97,7 +97,6 @@ if TEST_WITH_ROCM:
 if config.abi_compatible:
     xfail_list = [
-        "test_bernoulli1_cuda",  # cpp fallback op naming issue
         "test_profiler_mark_wrapper_call_cuda",
         "test_scaled_dot_product_attention_cuda_dynamic_shapes",
     ]
@@ -4784,9 +4784,17 @@ class InplaceBernoulliFallback(ExternKernel):
     def codegen(self, wrapper):
         (x,) = (t.codegen_reference() for t in self.inputs)
-        wrapper.writeline(
-            f"{self.get_kernel_name()}({x}, {', '.join(map(repr, self.constant_args))}){wrapper.ending}"
-        )
+        if V.graph.cpp_wrapper and config.abi_compatible:
+            # Inductor doesn't really support aten Generator, so the Generator kwarg is always NULL here,
+            # which needs to be explicitly generated for cpp wrapper
+            wrapper.writeline(
+                f"{self.get_kernel_name()}({x}, {', '.join(map(repr, self.constant_args))}, NULL){wrapper.ending}"
+            )
+        else:
+            wrapper.writeline(
+                f"{self.get_kernel_name()}({x}, {', '.join(map(repr, self.constant_args))}){wrapper.ending}"
+            )

     def should_allocate(self):
         return False
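To illustrate the branch above, a hedged toy model of the wrapper line each mode emits; the kernel name, buffer name, and probability are stand-ins, not actual Inductor output:

def emit(abi_compatible: bool, kernel_name: str, buf: str, p: float) -> str:
    if abi_compatible:
        # The C shim takes an explicit AtenGeneratorHandle* argument, which
        # Inductor always passes as NULL since aten Generators are not modeled.
        return f"{kernel_name}({buf}, {p!r}, NULL);"
    return f"{kernel_name}({buf}, {p!r});"

print(emit(True, "aoti_torch_cpu_bernoulli__float", "buf0", 0.5))
# -> aoti_torch_cpu_bernoulli__float(buf0, 0.5, NULL);
print(emit(False, "at::native::bernoulli_", "buf0", 0.5))
# -> at::native::bernoulli_(buf0, 0.5);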
@@ -4797,20 +4805,19 @@ class InplaceBernoulliFallback(ExternKernel):
     def get_unbacked_symbol_defs(self) -> Set[sympy.Symbol]:
         return set()

-    def __init__(self, x, *constant_args):
+    def __init__(self, op_overload, x, *constant_args):
         super().__init__(
             None,
             NoneLayout(x.get_device()),  # type: ignore[arg-type]
             self.unwrap_storage([x]),
             constant_args,
+            op_overload=op_overload,
         )
         self.name = V.graph.register_buffer(self)
         self.python_kernel_name = "aten.bernoulli_"
-        self.cpp_kernel_name = (
-            "aoti_torch_bernoulli_"
-            if config.abi_compatible
-            else "at::native::bernoulli_"
-        )
+        if not config.abi_compatible:
+            # TODO: this should be simplified once we switch to ABI-compatible only
+            self.cpp_kernel_name = "at::native::bernoulli_"
         mark_node_as_mutating(self, x)
@@ -1788,7 +1788,12 @@ def bernoulli_(x, *args):
         "cpu"
     ), "this should be handled in decomps unless config.fallback_random or the device is CPU"
     x.realize()
-    ir.InplaceBernoulliFallback(x, *args)
+    op_overload = (
+        aten.bernoulli_.float
+        if len(args) == 0 or isinstance(args[0], float)
+        else aten.bernoulli_.Tensor
+    )
+    ir.InplaceBernoulliFallback(op_overload, x, *args)
     return x
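A quick sketch of the overload-selection rule added above (a toy restatement of the lowering logic, not the lowering code itself):

import torch
aten = torch.ops.aten

def pick_overload(args):
    # No probability argument or a Python float -> the .float overload;
    # a Tensor probability -> the .Tensor overload.
    return (
        aten.bernoulli_.float
        if len(args) == 0 or isinstance(args[0], float)
        else aten.bernoulli_.Tensor
    )

assert pick_overload(()) is aten.bernoulli_.float
assert pick_overload((0.25,)) is aten.bernoulli_.float
assert pick_overload((torch.rand(4),)) is aten.bernoulli_.Tensor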
@@ -47,6 +47,8 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_avg_pool2d(AtenTensorHandle self
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_avg_pool2d_backward(AtenTensorHandle grad_output, AtenTensorHandle self, const int64_t* kernel_size, int64_t kernel_size_len_, const int64_t* stride, int64_t stride_len_, const int64_t* padding, int64_t padding_len_, int32_t ceil_mode, int32_t count_include_pad, int64_t* divisor_override, AtenTensorHandle* ret0);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_avg_pool3d(AtenTensorHandle self, const int64_t* kernel_size, int64_t kernel_size_len_, const int64_t* stride, int64_t stride_len_, const int64_t* padding, int64_t padding_len_, int32_t ceil_mode, int32_t count_include_pad, int64_t* divisor_override, AtenTensorHandle* ret0);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_avg_pool3d_backward(AtenTensorHandle grad_output, AtenTensorHandle self, const int64_t* kernel_size, int64_t kernel_size_len_, const int64_t* stride, int64_t stride_len_, const int64_t* padding, int64_t padding_len_, int32_t ceil_mode, int32_t count_include_pad, int64_t* divisor_override, AtenTensorHandle* ret0);
+AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_bernoulli__Tensor(AtenTensorHandle self, AtenTensorHandle p, AtenGeneratorHandle* generator);
+AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_bernoulli__float(AtenTensorHandle self, double p, AtenGeneratorHandle* generator);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_bmm_out(AtenTensorHandle out, AtenTensorHandle self, AtenTensorHandle mat2);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_bucketize_Tensor(AtenTensorHandle self, AtenTensorHandle boundaries, int32_t out_int32, int32_t right, AtenTensorHandle* ret0);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_cat(const AtenTensorHandle* tensors, int64_t tensors_len_, int64_t dim, AtenTensorHandle* ret0);
@@ -105,8 +107,8 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_randperm(int64_t n, int32_t* dty
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_repeat_interleave_Tensor(AtenTensorHandle repeats, int64_t* output_size, AtenTensorHandle* ret0);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_replication_pad1d_backward(AtenTensorHandle grad_output, AtenTensorHandle self, const int64_t* padding, int64_t padding_len_, AtenTensorHandle* ret0);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_replication_pad2d_backward(AtenTensorHandle grad_output, AtenTensorHandle self, const int64_t* padding, int64_t padding_len_, AtenTensorHandle* ret0);
-AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_resize_(AtenTensorHandle self, const int64_t* size, int64_t size_len_, int32_t* memory_format, AtenTensorHandle* ret0);
-AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_resize_as_(AtenTensorHandle self, AtenTensorHandle the_template, int32_t* memory_format, AtenTensorHandle* ret0);
+AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_resize_(AtenTensorHandle self, const int64_t* size, int64_t size_len_, int32_t* memory_format);
+AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_resize_as_(AtenTensorHandle self, AtenTensorHandle the_template, int32_t* memory_format);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_scatter_src_out(AtenTensorHandle out, AtenTensorHandle self, int64_t dim, AtenTensorHandle index, AtenTensorHandle src);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_scatter_value_out(AtenTensorHandle out, AtenTensorHandle self, int64_t dim, AtenTensorHandle index, double value);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_scatter_reduce_two_out(AtenTensorHandle out, AtenTensorHandle self, int64_t dim, AtenTensorHandle index, AtenTensorHandle src, const char* reduce, int32_t include_self);
@@ -55,6 +55,8 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_avg_pool2d(AtenTensorHandle sel
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_avg_pool2d_backward(AtenTensorHandle grad_output, AtenTensorHandle self, const int64_t* kernel_size, int64_t kernel_size_len_, const int64_t* stride, int64_t stride_len_, const int64_t* padding, int64_t padding_len_, int32_t ceil_mode, int32_t count_include_pad, int64_t* divisor_override, AtenTensorHandle* ret0);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_avg_pool3d(AtenTensorHandle self, const int64_t* kernel_size, int64_t kernel_size_len_, const int64_t* stride, int64_t stride_len_, const int64_t* padding, int64_t padding_len_, int32_t ceil_mode, int32_t count_include_pad, int64_t* divisor_override, AtenTensorHandle* ret0);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_avg_pool3d_backward(AtenTensorHandle grad_output, AtenTensorHandle self, const int64_t* kernel_size, int64_t kernel_size_len_, const int64_t* stride, int64_t stride_len_, const int64_t* padding, int64_t padding_len_, int32_t ceil_mode, int32_t count_include_pad, int64_t* divisor_override, AtenTensorHandle* ret0);
+AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_bernoulli__Tensor(AtenTensorHandle self, AtenTensorHandle p, AtenGeneratorHandle* generator);
+AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_bernoulli__float(AtenTensorHandle self, double p, AtenGeneratorHandle* generator);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_bmm_out(AtenTensorHandle out, AtenTensorHandle self, AtenTensorHandle mat2);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_bucketize_Tensor(AtenTensorHandle self, AtenTensorHandle boundaries, int32_t out_int32, int32_t right, AtenTensorHandle* ret0);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_cat(const AtenTensorHandle* tensors, int64_t tensors_len_, int64_t dim, AtenTensorHandle* ret0);
@@ -112,8 +114,8 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_randperm(int64_t n, int32_t* dt
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_repeat_interleave_Tensor(AtenTensorHandle repeats, int64_t* output_size, AtenTensorHandle* ret0);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_replication_pad1d_backward(AtenTensorHandle grad_output, AtenTensorHandle self, const int64_t* padding, int64_t padding_len_, AtenTensorHandle* ret0);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_replication_pad2d_backward(AtenTensorHandle grad_output, AtenTensorHandle self, const int64_t* padding, int64_t padding_len_, AtenTensorHandle* ret0);
-AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_resize_(AtenTensorHandle self, const int64_t* size, int64_t size_len_, int32_t* memory_format, AtenTensorHandle* ret0);
-AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_resize_as_(AtenTensorHandle self, AtenTensorHandle the_template, int32_t* memory_format, AtenTensorHandle* ret0);
+AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_resize_(AtenTensorHandle self, const int64_t* size, int64_t size_len_, int32_t* memory_format);
+AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_resize_as_(AtenTensorHandle self, AtenTensorHandle the_template, int32_t* memory_format);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_scatter_src_out(AtenTensorHandle out, AtenTensorHandle self, int64_t dim, AtenTensorHandle index, AtenTensorHandle src);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_scatter_value_out(AtenTensorHandle out, AtenTensorHandle self, int64_t dim, AtenTensorHandle index, double value);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_scatter_reduce_two_out(AtenTensorHandle out, AtenTensorHandle self, int64_t dim, AtenTensorHandle index, AtenTensorHandle src, const char* reduce, int32_t include_self);
@@ -25,6 +25,8 @@ inductor_fallback_ops = {
     "aten.avg_pool2d.default",
     "aten.avg_pool3d_backward.default",
     "aten.avg_pool3d.default",
+    "aten.bernoulli_.float",
+    "aten.bernoulli_.Tensor",
     "aten.bmm.out",
     "aten.bucketize.Tensor",
     "aten.cat.default",
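For intuition, a hedged sketch of how a fallback-op entry maps onto the C shim symbol names declared above; the mangling rule here is inferred from the declared names, not taken from the actual torchgen implementation:

def shim_name(device: str, qualified_op: str) -> str:
    # "aten.bernoulli_.float" -> ("aten", "bernoulli_", "float")
    _ns, name, overload = qualified_op.split(".")
    suffix = "" if overload == "default" else f"_{overload}"
    return f"aoti_torch_{device}_{name}{suffix}"

assert shim_name("cpu", "aten.bernoulli_.float") == "aoti_torch_cpu_bernoulli__float"
assert shim_name("cuda", "aten.bernoulli_.Tensor") == "aoti_torch_cuda_bernoulli__Tensor"
assert shim_name("cpu", "aten.cat.default") == "aoti_torch_cpu_cat"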
@@ -249,18 +249,18 @@ def gen_declaration_and_definition(
         return declaration_definition_cache[(func_name, device, backend_call)]

     if schema.is_out_fn():
-        # out_variant has out arguments in the front, and it's ok to ignore return value
+        # out_variant has out arguments in the front, and it's ok to ignore return values
         # because C shim functions only return AOTITorchError
+        # Somehow at::native out-variant functions have out arguments in the back
         args, callsite_exprs = gen_arguments(
-            [*schema.arguments.out, *schema.arguments.flat_non_out]
+            [*schema.arguments.flat_non_out, *schema.arguments.out]
+            if "at::native" in backend_call
+            else [*schema.arguments.out, *schema.arguments.flat_non_out],
         )
         ret_assignments: List[str] = []
     else:
         args, callsite_exprs = gen_arguments(schema.arguments.flat_all)
-        ret_declarations, ret_assignments = gen_returns(schema)
+        # ignore return values for inplace ops
+        ret_declarations, ret_assignments = (
+            ([], []) if schema.name.name.inplace else gen_returns(schema)
+        )
     args.extend(ret_declarations)

     declaration = f"AOTITorchError aoti_torch_{device}_{func_name}({', '.join(args)})"