[AOTI] Add _weight_int4pack_mm to the C shim fallback list (#151059)

Summary: Add aten._weight_int4pack_mm.default to the AOTInductor C shim fallback op list, declare the corresponding CUDA C shim entry point, and add a parametrized AOTI test covering the op.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151059
Approved by: https://github.com/yushangdi
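
For context, here is a minimal end-to-end sketch of the path this change enables: pack an int4 weight, export a module that calls torch._weight_int4pack_mm, and compile it with AOTInductor so the op is routed through the new C shim fallback. The module name, shapes, and the aoti_compile_and_package / aoti_load_package packaging flow are illustrative assumptions, not part of this PR; the quantization helpers mirror the new test below.

# Illustrative sketch only (not part of this PR); requires a CUDA device.
import torch
from torch._inductor import aoti_compile_and_package, aoti_load_package
from torch.testing._internal.common_quantization import _group_quantize_tensor


class Int4MmModule(torch.nn.Module):
    # Hypothetical module, mirroring the Model added in the test below.
    def __init__(self, packed_weight, scales_and_zeros, q_group):
        super().__init__()
        self.packed_weight = packed_weight
        self.scales_and_zeros = scales_and_zeros
        self.q_group = q_group

    def forward(self, a):
        # Argument order matches the op schema and the new C shim:
        # (self, mat2, qGroupSize, qScaleAndZeros)
        return torch._weight_int4pack_mm(
            a, self.packed_weight, self.q_group, self.scales_and_zeros
        )


m, k, n, q_group = 32, 64, 64, 32  # illustrative shapes; k is a multiple of q_group
a = torch.rand(m, k, device="cuda", dtype=torch.bfloat16)
b = torch.rand(k, n, device="cuda", dtype=torch.bfloat16)

# Group-quantize to int4 and pack the weight, as the new test does.
b_int32, scales_and_zeros = _group_quantize_tensor(b, n_bit=4, q_group_size=q_group)
b_int4pack = torch._convert_weight_to_int4pack(b_int32, innerKTiles=2)

mod = Int4MmModule(b_int4pack, scales_and_zeros, q_group)
ep = torch.export.export(mod, (a,))
package_path = aoti_compile_and_package(ep)  # AOT-compiles; emits a .pt2 package
compiled = aoti_load_package(package_path)   # loads the compiled model
torch.testing.assert_close(compiled(a), mod(a))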
Authored by Bin Bao on 2025-04-10 15:36:41 -07:00; committed by PyTorch MergeBot
parent 12281f9c18
commit a78ac409b5
3 changed files with 40 additions and 0 deletions

@@ -33,6 +33,7 @@ from torch.testing._internal.common_device_type import (
    skipCUDAIf,
)
from torch.testing._internal.common_quantization import (
    _group_quantize_tensor,
    skip_if_no_torchvision,
    skipIfNoFBGEMM,
)
@@ -42,6 +43,7 @@ from torch.testing._internal.common_utils import (
    IS_FBCODE,
    IS_MACOS,
    IS_WINDOWS,
    parametrize,
    skipIfRocm,
    skipIfXpu,
    TEST_WITH_ROCM,
@@ -4936,6 +4938,42 @@ class AOTInductorTestsTemplate:
        )
        self.check_model(Model(), example_inputs)

    @skipIfXpu(
        msg="aten::convert_weight_to_int4pack is not currently implemented for XPU"
    )
    @parametrize("m", [32])
    @parametrize("n", [64])
    @parametrize("q_group", [32, 64])
    @parametrize("num_groups", [1, 2])
    def test__weight_int4pack_mm(self, m, n, q_group, num_groups):
        if self.device != GPU_TYPE:
            raise unittest.SkipTest("requires GPU")

        class Model(torch.nn.Module):
            def __init__(self, weight, scale_and_zeros) -> None:
                super().__init__()
                self.weight = weight
                self.scale_and_zeros = scale_and_zeros

            def forward(self, a):
                return torch._weight_int4pack_mm(
                    a, self.weight, q_group, self.scale_and_zeros
                )

        def convert_weight_to_int4pack(b):
            b_int32, b_scales_and_zeros = _group_quantize_tensor(
                b, n_bit=4, q_group_size=q_group
            )
            b_int4pack = torch._convert_weight_to_int4pack(b_int32, innerKTiles=2)
            return b_int4pack, b_scales_and_zeros

        k = q_group * num_groups
        a = torch.rand((m, k), device=self.device, dtype=torch.bfloat16)
        b = torch.rand((k, n), device=self.device, dtype=torch.bfloat16)
        b_int4pack, b_scales_and_zeros_f32 = convert_weight_to_int4pack(b)
        model = Model(b_int4pack, b_scales_and_zeros_f32)
        self.check_model(model, (a,))

    def test_assert_tensor_meta(self):
        class Module(torch.nn.Module):
            def forward(self, x):

@@ -49,6 +49,7 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__segment_reduce_backward(AtenTe
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__thnn_fused_lstm_cell(AtenTensorHandle input_gates, AtenTensorHandle hidden_gates, AtenTensorHandle cx, AtenTensorHandle* input_bias, AtenTensorHandle* hidden_bias, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__to_sparse(AtenTensorHandle self, int32_t* layout, const int64_t** blocksize, int64_t blocksize_len_, int64_t* dense_dim, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__trilinear(AtenTensorHandle i1, AtenTensorHandle i2, AtenTensorHandle i3, const int64_t* expand1, int64_t expand1_len_, const int64_t* expand2, int64_t expand2_len_, const int64_t* expand3, int64_t expand3_len_, const int64_t* sumdim, int64_t sumdim_len_, int64_t unroll_dim, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__weight_int4pack_mm(AtenTensorHandle self, AtenTensorHandle mat2, int64_t qGroupSize, AtenTensorHandle qScaleAndZeros, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_adaptive_max_pool2d(AtenTensorHandle self, const int64_t* output_size, int64_t output_size_len_, AtenTensorHandle* ret0, AtenTensorHandle* ret1);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_adaptive_max_pool2d_backward(AtenTensorHandle grad_output, AtenTensorHandle self, AtenTensorHandle indices, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_adaptive_max_pool3d(AtenTensorHandle self, const int64_t* output_size, int64_t output_size_len_, AtenTensorHandle* ret0, AtenTensorHandle* ret1);
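
The new shim entry mirrors the op's ATen schema one-to-one, so AOTInductor's generated C++ can pass (self, mat2, qGroupSize, qScaleAndZeros) straight through the stable C ABI. As a sketch of that correspondence (illustrative, not part of this change), the schema can be inspected from Python:

# Illustrative cross-check of the schema the shim wraps.
import torch

print(torch.ops.aten._weight_int4pack_mm.default._schema)
# Expected to show something like:
#   _weight_int4pack_mm(Tensor self, Tensor mat2, int qGroupSize,
#                       Tensor qScaleAndZeros) -> Tensor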

@@ -51,6 +51,7 @@ inductor_fallback_ops = {
    "aten._thnn_fused_lstm_cell.default",
    "aten._to_sparse.default",
    "aten._trilinear.default",
    "aten._weight_int4pack_mm.default",
    "aten._weight_int8pack_mm.default",
    "aten.adaptive_max_pool2d_backward.default",
    "aten.adaptive_max_pool2d.default",