Add new ops in fallback ops (#154251)

Fixes #ISSUE_NUMBER

## Background

Task: [T222738229](https://www.internalfb.com/intern/tasks/?t=222738229)

It's the first starter task on the project **_Enabling TorchNative Standalone on Whisper_**.  We are using cshim to create a layer of abstraction between _**libtorch**_ and **_AOTInductor generated artifacts_**.

So we needed to add an entry in the cshim for every API surface in libtorch. And we only care about operators that AOTInductor does not handle. And for this task, we only wanted to add it for the following ops.

## What I've done?

4 new fallback ops are added that show up in the Whisper model. (torchgen/aoti/fallback_ops.py)

- aten.permute (default)
- aten.squueze (dim)
- aten.abs (default)
- aten.hann_window (default)

Then I ran the below command to generate new header C shim header files. As it says [here](7e86a7c015/torchgen/gen.py (L2424-L2436%20for%20details))
`python torchgen/gen.py --update-aoti-c-shim`

Then, `python setup.py develop` to rebuild PyTorch

## Testing

Also 4 new tests have been added on test/inductor/test_aot_inductor.py

- test_proxy_executor_permute
- test_proxy_executor_abs
- test_proxy_executor_squeeze
- test_proxy_executor_hann

I ran these commands to test it (inside local pytorch root folder):

`python test/inductor/test_aot_inductor.py -k test_proxy_executor_permute`
`python test/inductor/test_aot_inductor.py -k test_proxy_executor_abs`
`python test/inductor/test_aot_inductor.py -k test_proxy_executor_squeeze`
`python test/inductor/test_aot_inductor.py -k test_proxy_executor_hann`

## NOTE:
I didn't see any order between the tests inside _test/inductor/test_aot_inductor.py_. That's why, I added new tests just after the test given in the example.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154251
Approved by: https://github.com/angelayi
This commit is contained in:
Serhat Gundem 2025-05-28 22:11:05 +00:00 committed by PyTorch MergeBot
parent d2f506cae8
commit 11129d9317
6 changed files with 68 additions and 0 deletions

View File

@ -3625,6 +3625,54 @@ class AOTInductorTestsTemplate:
m = M()
self.check_model(m, example_args, dynamic_shapes=dynamic_shapes)
def test_proxy_executor_permute(self):
class M(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
def forward(self, x):
return torch.ops.aten.permute.default(x, [0, 2, 1])
example_args = (torch.randn((1, 3001, 201), dtype=torch.complex64),)
m = M()
self.check_model(m, example_args)
def test_proxy_executor_abs(self):
class M(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
def forward(self, x):
return torch.ops.aten.abs.default(x)
example_args = (torch.randn((1, 3001, 201), dtype=torch.complex64),)
m = M()
self.check_model(m, example_args)
def test_proxy_executor_squeeze(self):
class M(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
def forward(self, x):
return torch.ops.aten.squeeze.dim(x, 0)
example_args = (torch.randn((1, 300, 201), dtype=torch.complex64),)
m = M()
self.check_model(m, example_args)
def test_proxy_executor_hann(self):
class M(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
def forward(self):
return torch.ops.aten.hann_window.default(400)
example_args = ()
m = M()
self.check_model(m, example_args)
def test_fqn(self):
class NestedChild(torch.nn.Module):
def __init__(self) -> None:

View File

@ -43,6 +43,7 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__segment_reduce_backward(AtenTen
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__to_sparse(AtenTensorHandle self, int32_t* layout, const int64_t** blocksize, int64_t blocksize_len_, int64_t* dense_dim, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__trilinear(AtenTensorHandle i1, AtenTensorHandle i2, AtenTensorHandle i3, const int64_t* expand1, int64_t expand1_len_, const int64_t* expand2, int64_t expand2_len_, const int64_t* expand3, int64_t expand3_len_, const int64_t* sumdim, int64_t sumdim_len_, int64_t unroll_dim, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__weight_int8pack_mm(AtenTensorHandle self, AtenTensorHandle mat2, AtenTensorHandle scales, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_abs(AtenTensorHandle self, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_adaptive_max_pool2d(AtenTensorHandle self, const int64_t* output_size, int64_t output_size_len_, AtenTensorHandle* ret0, AtenTensorHandle* ret1);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_adaptive_max_pool2d_backward(AtenTensorHandle grad_output, AtenTensorHandle self, AtenTensorHandle indices, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_adaptive_max_pool3d(AtenTensorHandle self, const int64_t* output_size, int64_t output_size_len_, AtenTensorHandle* ret0, AtenTensorHandle* ret1);
@ -79,6 +80,7 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_fractional_max_pool3d_backward(A
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_gcd(AtenTensorHandle self, AtenTensorHandle other, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_geqrf(AtenTensorHandle self, AtenTensorHandle* ret0, AtenTensorHandle* ret1);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_grid_sampler_2d_backward(AtenTensorHandle grad_output, AtenTensorHandle input, AtenTensorHandle grid, int64_t interpolation_mode, int64_t padding_mode, int32_t align_corners, const int32_t* output_mask, int64_t output_mask_len_, AtenTensorHandle* ret0, AtenTensorHandle* ret1);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_hann_window(int64_t window_length, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_histc(AtenTensorHandle self, int64_t bins, double min, double max, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_histogram_bin_ct(AtenTensorHandle self, int64_t bins, const double** range, int64_t range_len_, AtenTensorHandle* weight, int32_t density, AtenTensorHandle* ret0, AtenTensorHandle* ret1);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_index_Tensor(AtenTensorHandle self, const AtenTensorHandle** indices, int64_t indices_len_, AtenTensorHandle* ret0);
@ -106,6 +108,7 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_native_dropout(AtenTensorHandle
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_nonzero(AtenTensorHandle self, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_normal_functional(AtenTensorHandle self, double mean, double std, AtenGeneratorHandle* generator, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_ormqr(AtenTensorHandle self, AtenTensorHandle input2, AtenTensorHandle input3, int32_t left, int32_t transpose, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_permute(AtenTensorHandle self, const int64_t* dims, int64_t dims_len_, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_polar(AtenTensorHandle abs, AtenTensorHandle angle, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_pow_Scalar(double self, AtenTensorHandle exponent, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_pow_Tensor_Scalar(AtenTensorHandle self, double exponent, AtenTensorHandle* ret0);
@ -136,6 +139,7 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_slice_Tensor(AtenTensorHandle se
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_soft_margin_loss_backward(AtenTensorHandle grad_output, AtenTensorHandle self, AtenTensorHandle target, int64_t reduction, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_sort(AtenTensorHandle self, int64_t dim, int32_t descending, AtenTensorHandle* ret0, AtenTensorHandle* ret1);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_sort_stable(AtenTensorHandle self, int32_t* stable, int64_t dim, int32_t descending, AtenTensorHandle* ret0, AtenTensorHandle* ret1);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_squeeze_dim(AtenTensorHandle self, int64_t dim, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_to_sparse(AtenTensorHandle self, int32_t* layout, const int64_t** blocksize, int64_t blocksize_len_, int64_t* dense_dim, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_topk(AtenTensorHandle self, int64_t k, int64_t dim, int32_t largest, int32_t sorted, AtenTensorHandle* ret0, AtenTensorHandle* ret1);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_triangular_solve(AtenTensorHandle self, AtenTensorHandle A, int32_t upper, int32_t transpose, int32_t unitriangular, AtenTensorHandle* ret0, AtenTensorHandle* ret1);

View File

@ -50,6 +50,7 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__thnn_fused_lstm_cell(AtenTenso
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__to_sparse(AtenTensorHandle self, int32_t* layout, const int64_t** blocksize, int64_t blocksize_len_, int64_t* dense_dim, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__trilinear(AtenTensorHandle i1, AtenTensorHandle i2, AtenTensorHandle i3, const int64_t* expand1, int64_t expand1_len_, const int64_t* expand2, int64_t expand2_len_, const int64_t* expand3, int64_t expand3_len_, const int64_t* sumdim, int64_t sumdim_len_, int64_t unroll_dim, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__weight_int4pack_mm(AtenTensorHandle self, AtenTensorHandle mat2, int64_t qGroupSize, AtenTensorHandle qScaleAndZeros, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_abs(AtenTensorHandle self, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_adaptive_max_pool2d(AtenTensorHandle self, const int64_t* output_size, int64_t output_size_len_, AtenTensorHandle* ret0, AtenTensorHandle* ret1);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_adaptive_max_pool2d_backward(AtenTensorHandle grad_output, AtenTensorHandle self, AtenTensorHandle indices, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_adaptive_max_pool3d(AtenTensorHandle self, const int64_t* output_size, int64_t output_size_len_, AtenTensorHandle* ret0, AtenTensorHandle* ret1);
@ -86,6 +87,7 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_fractional_max_pool3d_backward(
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_gcd(AtenTensorHandle self, AtenTensorHandle other, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_geqrf(AtenTensorHandle self, AtenTensorHandle* ret0, AtenTensorHandle* ret1);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_grid_sampler_2d_backward(AtenTensorHandle grad_output, AtenTensorHandle input, AtenTensorHandle grid, int64_t interpolation_mode, int64_t padding_mode, int32_t align_corners, const int32_t* output_mask, int64_t output_mask_len_, AtenTensorHandle* ret0, AtenTensorHandle* ret1);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_hann_window(int64_t window_length, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_histc(AtenTensorHandle self, int64_t bins, double min, double max, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_index_Tensor(AtenTensorHandle self, const AtenTensorHandle** indices, int64_t indices_len_, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_index_put(AtenTensorHandle self, const AtenTensorHandle** indices, int64_t indices_len_, AtenTensorHandle values, int32_t accumulate, AtenTensorHandle* ret0);
@ -112,6 +114,7 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_native_dropout(AtenTensorHandle
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_nonzero(AtenTensorHandle self, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_normal_functional(AtenTensorHandle self, double mean, double std, AtenGeneratorHandle* generator, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_ormqr(AtenTensorHandle self, AtenTensorHandle input2, AtenTensorHandle input3, int32_t left, int32_t transpose, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_permute(AtenTensorHandle self, const int64_t* dims, int64_t dims_len_, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_polar(AtenTensorHandle abs, AtenTensorHandle angle, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_pow_Scalar(double self, AtenTensorHandle exponent, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_pow_Tensor_Scalar(AtenTensorHandle self, double exponent, AtenTensorHandle* ret0);
@ -142,6 +145,7 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_slice_Tensor(AtenTensorHandle s
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_soft_margin_loss_backward(AtenTensorHandle grad_output, AtenTensorHandle self, AtenTensorHandle target, int64_t reduction, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_sort(AtenTensorHandle self, int64_t dim, int32_t descending, AtenTensorHandle* ret0, AtenTensorHandle* ret1);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_sort_stable(AtenTensorHandle self, int32_t* stable, int64_t dim, int32_t descending, AtenTensorHandle* ret0, AtenTensorHandle* ret1);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_squeeze_dim(AtenTensorHandle self, int64_t dim, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_to_sparse(AtenTensorHandle self, int32_t* layout, const int64_t** blocksize, int64_t blocksize_len_, int64_t* dense_dim, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_topk(AtenTensorHandle self, int64_t k, int64_t dim, int32_t largest, int32_t sorted, AtenTensorHandle* ret0, AtenTensorHandle* ret1);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_triangular_solve(AtenTensorHandle self, AtenTensorHandle A, int32_t upper, int32_t transpose, int32_t unitriangular, AtenTensorHandle* ret0, AtenTensorHandle* ret1);

View File

@ -24,6 +24,7 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__scaled_dot_product_fused_attent
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__trilinear(AtenTensorHandle i1, AtenTensorHandle i2, AtenTensorHandle i3, const int64_t* expand1, int64_t expand1_len_, const int64_t* expand2, int64_t expand2_len_, const int64_t* expand3, int64_t expand3_len_, const int64_t* sumdim, int64_t sumdim_len_, int64_t unroll_dim, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__weight_int4pack_mm(AtenTensorHandle self, AtenTensorHandle mat2, int64_t qGroupSize, AtenTensorHandle qScaleAndZeros, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__weight_int8pack_mm(AtenTensorHandle self, AtenTensorHandle mat2, AtenTensorHandle scales, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_abs(AtenTensorHandle self, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_adaptive_max_pool2d(AtenTensorHandle self, const int64_t* output_size, int64_t output_size_len_, AtenTensorHandle* ret0, AtenTensorHandle* ret1);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_adaptive_max_pool2d_backward(AtenTensorHandle grad_output, AtenTensorHandle self, AtenTensorHandle indices, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_add_Scalar(AtenTensorHandle self, double other, double alpha, AtenTensorHandle* ret0);
@ -48,6 +49,7 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_cummin(AtenTensorHandle self, in
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_cumprod(AtenTensorHandle self, int64_t dim, int32_t* dtype, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_cumsum(AtenTensorHandle self, int64_t dim, int32_t* dtype, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_exponential(AtenTensorHandle self, double lambd, AtenGeneratorHandle* generator, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_hann_window(int64_t window_length, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_histc(AtenTensorHandle self, int64_t bins, double min, double max, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_histogram_bin_ct(AtenTensorHandle self, int64_t bins, const double** range, int64_t range_len_, AtenTensorHandle* weight, int32_t density, AtenTensorHandle* ret0, AtenTensorHandle* ret1);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_index_Tensor(AtenTensorHandle self, const AtenTensorHandle** indices, int64_t indices_len_, AtenTensorHandle* ret0);
@ -67,6 +69,7 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_mul_Tensor(AtenTensorHandle self
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_nanmedian(AtenTensorHandle self, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_nonzero(AtenTensorHandle self, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_normal_functional(AtenTensorHandle self, double mean, double std, AtenGeneratorHandle* generator, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_permute(AtenTensorHandle self, const int64_t* dims, int64_t dims_len_, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_polar(AtenTensorHandle abs, AtenTensorHandle angle, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_pow_Scalar(double self, AtenTensorHandle exponent, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_pow_Tensor_Scalar(AtenTensorHandle self, double exponent, AtenTensorHandle* ret0);
@ -96,6 +99,7 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_slice_Tensor(AtenTensorHandle se
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_soft_margin_loss_backward(AtenTensorHandle grad_output, AtenTensorHandle self, AtenTensorHandle target, int64_t reduction, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_sort(AtenTensorHandle self, int64_t dim, int32_t descending, AtenTensorHandle* ret0, AtenTensorHandle* ret1);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_sort_stable(AtenTensorHandle self, int32_t* stable, int64_t dim, int32_t descending, AtenTensorHandle* ret0, AtenTensorHandle* ret1);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_squeeze_dim(AtenTensorHandle self, int64_t dim, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_to_sparse(AtenTensorHandle self, int32_t* layout, const int64_t** blocksize, int64_t blocksize_len_, int64_t* dense_dim, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_topk(AtenTensorHandle self, int64_t k, int64_t dim, int32_t largest, int32_t sorted, AtenTensorHandle* ret0, AtenTensorHandle* ret1);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_triangular_solve(AtenTensorHandle self, AtenTensorHandle A, int32_t upper, int32_t transpose, int32_t unitriangular, AtenTensorHandle* ret0, AtenTensorHandle* ret1);

View File

@ -16,6 +16,7 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu__fused_moving_avg_obs_fq_helper_
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu__scaled_dot_product_fused_attention_overrideable(AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle* attn_bias, double dropout_p, int32_t is_causal, int32_t return_debug_mask, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3, int64_t* ret4, int64_t* ret5, AtenTensorHandle* ret6, AtenTensorHandle* ret7, AtenTensorHandle* ret8);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu__scaled_dot_product_fused_attention_overrideable_backward(AtenTensorHandle grad_out, AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle attn_bias, const int32_t* grad_input_mask, int64_t grad_input_mask_len_, AtenTensorHandle out, AtenTensorHandle logsumexp, AtenTensorHandle cum_seq_q, AtenTensorHandle cum_seq_k, int64_t max_q, int64_t max_k, double dropout_p, int32_t is_causal, AtenTensorHandle philox_seed, AtenTensorHandle philox_offset, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu__trilinear(AtenTensorHandle i1, AtenTensorHandle i2, AtenTensorHandle i3, const int64_t* expand1, int64_t expand1_len_, const int64_t* expand2, int64_t expand2_len_, const int64_t* expand3, int64_t expand3_len_, const int64_t* sumdim, int64_t sumdim_len_, int64_t unroll_dim, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_abs(AtenTensorHandle self, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_add_Scalar(AtenTensorHandle self, double other, double alpha, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_addbmm(AtenTensorHandle self, AtenTensorHandle batch1, AtenTensorHandle batch2, double beta, double alpha, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_addmm_out(AtenTensorHandle out, AtenTensorHandle self, AtenTensorHandle mat1, AtenTensorHandle mat2, double beta, double alpha);
@ -28,6 +29,7 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_convolution_backward(AtenTensorH
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_cummax(AtenTensorHandle self, int64_t dim, AtenTensorHandle* ret0, AtenTensorHandle* ret1);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_cummin(AtenTensorHandle self, int64_t dim, AtenTensorHandle* ret0, AtenTensorHandle* ret1);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_exponential(AtenTensorHandle self, double lambd, AtenGeneratorHandle* generator, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_hann_window(int64_t window_length, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_index_put(AtenTensorHandle self, const AtenTensorHandle** indices, int64_t indices_len_, AtenTensorHandle values, int32_t accumulate, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_kthvalue(AtenTensorHandle self, int64_t k, int64_t dim, int32_t keepdim, AtenTensorHandle* ret0, AtenTensorHandle* ret1);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_logcumsumexp(AtenTensorHandle self, int64_t dim, AtenTensorHandle* ret0);
@ -36,6 +38,7 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_masked_scatter_backward(AtenTens
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_mm_out(AtenTensorHandle out, AtenTensorHandle self, AtenTensorHandle mat2);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_mul_Scalar(AtenTensorHandle self, double other, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_normal_functional(AtenTensorHandle self, double mean, double std, AtenGeneratorHandle* generator, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_permute(AtenTensorHandle self, const int64_t* dims, int64_t dims_len_, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_polar(AtenTensorHandle abs, AtenTensorHandle angle, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_rand(const int64_t* size, int64_t size_len_, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_rand_generator(const int64_t* size, int64_t size_len_, AtenGeneratorHandle* generator, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0);
@ -51,6 +54,7 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_resize_as_(AtenTensorHandle self
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_slice_Tensor(AtenTensorHandle self, int64_t dim, int64_t* start, int64_t* end, int64_t step, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_soft_margin_loss_backward(AtenTensorHandle grad_output, AtenTensorHandle self, AtenTensorHandle target, int64_t reduction, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_sort(AtenTensorHandle self, int64_t dim, int32_t descending, AtenTensorHandle* ret0, AtenTensorHandle* ret1);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_squeeze_dim(AtenTensorHandle self, int64_t dim, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_to_sparse(AtenTensorHandle self, int32_t* layout, const int64_t** blocksize, int64_t blocksize_len_, int64_t* dense_dim, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_uniform(AtenTensorHandle self, double from, double to, AtenGeneratorHandle* generator, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_view_dtype(AtenTensorHandle self, int32_t dtype, AtenTensorHandle* ret0);

View File

@ -53,6 +53,7 @@ inductor_fallback_ops = {
"aten._trilinear.default",
"aten._weight_int4pack_mm.default",
"aten._weight_int8pack_mm.default",
"aten.abs.default",
"aten.adaptive_max_pool2d_backward.default",
"aten.adaptive_max_pool2d.default",
"aten.adaptive_max_pool3d_backward.default",
@ -89,6 +90,7 @@ inductor_fallback_ops = {
"aten.gcd.default",
"aten.geqrf.default",
"aten.grid_sampler_2d_backward.default",
"aten.hann_window.default",
"aten.histc.default",
"aten.histogram.bin_ct",
"aten.index_put.default",
@ -116,6 +118,7 @@ inductor_fallback_ops = {
"aten.nonzero.default",
"aten.normal_functional.default",
"aten.ormqr.default",
"aten.permute.default",
"aten.polar.default",
"aten.pow.Scalar",
"aten.pow.Tensor_Scalar",
@ -146,6 +149,7 @@ inductor_fallback_ops = {
"aten.soft_margin_loss_backward.default",
"aten.sort.default",
"aten.sort.stable",
"aten.squeeze.dim",
"aten.to_sparse.default",
"aten.topk.default",
"aten.triangular_solve.default",