Mirror of https://github.com/zebrajr/pytorch.git (synced 2025-12-06 12:20:52 +01:00)
[AOTI] Add _weight_int4pack_mm to the C shim fallback list (#151059)
Summary: As title

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151059
Approved by: https://github.com/yushangdi
parent 12281f9c18
commit a78ac409b5
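For context (not part of this commit): a minimal eager-mode sketch of the op being added to the C shim fallback list. It mirrors the shapes and helpers used by the new test in the diff below and assumes a CUDA device with bfloat16 support.

# Eager-mode sketch of torch._weight_int4pack_mm (not part of this commit).
# Shapes and helpers mirror the test added below; assumes CUDA with bfloat16.
import torch
from torch.testing._internal.common_quantization import _group_quantize_tensor

m, n, k, q_group = 32, 64, 64, 32
a = torch.rand((m, k), device="cuda", dtype=torch.bfloat16)
b = torch.rand((k, n), device="cuda", dtype=torch.bfloat16)

# Group-quantize the weight to 4 bits and pack it into the tiled int4 layout
# expected by _weight_int4pack_mm.
b_int32, b_scales_and_zeros = _group_quantize_tensor(b, n_bit=4, q_group_size=q_group)
b_int4pack = torch._convert_weight_to_int4pack(b_int32, innerKTiles=2)

out = torch._weight_int4pack_mm(a, b_int4pack, q_group, b_scales_and_zeros)
print(out.shape)  # torch.Size([32, 64])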
@@ -33,6 +33,7 @@ from torch.testing._internal.common_device_type import (
    skipCUDAIf,
)
from torch.testing._internal.common_quantization import (
    _group_quantize_tensor,
    skip_if_no_torchvision,
    skipIfNoFBGEMM,
)
@@ -42,6 +43,7 @@ from torch.testing._internal.common_utils import (
    IS_FBCODE,
    IS_MACOS,
    IS_WINDOWS,
    parametrize,
    skipIfRocm,
    skipIfXpu,
    TEST_WITH_ROCM,
@@ -4936,6 +4938,42 @@ class AOTInductorTestsTemplate:
        )
        self.check_model(Model(), example_inputs)

    @skipIfXpu(
        msg="aten::convert_weight_to_int4pack is not currently implemented for XPU"
    )
    @parametrize("m", [32])
    @parametrize("n", [64])
    @parametrize("q_group", [32, 64])
    @parametrize("num_groups", [1, 2])
    def test__weight_int4pack_mm(self, m, n, q_group, num_groups):
        if self.device != GPU_TYPE:
            raise unittest.SkipTest("requires GPU")

        class Model(torch.nn.Module):
            def __init__(self, weight, scale_and_zeros) -> None:
                super().__init__()
                self.weight = weight
                self.scale_and_zeros = scale_and_zeros

            def forward(self, a):
                return torch._weight_int4pack_mm(
                    a, self.weight, q_group, self.scale_and_zeros
                )

        def convert_weight_to_int4pack(b):
            b_int32, b_scales_and_zeros = _group_quantize_tensor(
                b, n_bit=4, q_group_size=q_group
            )
            b_int4pack = torch._convert_weight_to_int4pack(b_int32, innerKTiles=2)
            return b_int4pack, b_scales_and_zeros

        k = q_group * num_groups
        a = torch.rand((m, k), device=self.device, dtype=torch.bfloat16)
        b = torch.rand((k, n), device=self.device, dtype=torch.bfloat16)
        b_int4pack, b_scales_and_zeros_f32 = convert_weight_to_int4pack(b)
        model = Model(b_int4pack, b_scales_and_zeros_f32)
        self.check_model(model, (a,))

    def test_assert_tensor_meta(self):
        class Module(torch.nn.Module):
            def forward(self, x):
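The test above drives the op through AOTInductor via self.check_model. Outside the test harness, a rough sketch of the same path is below; it assumes CUDA and a recent PyTorch build that exposes torch.export.export, torch._inductor.aoti_compile_and_package, and torch._inductor.aoti_load_package. With the fallback registered, the generated wrapper can dispatch the op through the aoti_torch_cuda__weight_int4pack_mm shim declared in the next hunk.

# Sketch only (not from this PR): export and AOT-compile a module that calls
# torch._weight_int4pack_mm. Assumes CUDA, bfloat16 support, and a recent
# PyTorch build with the aoti_compile_and_package / aoti_load_package APIs.
import torch
import torch._inductor
from torch.testing._internal.common_quantization import _group_quantize_tensor

q_group, m, k, n = 32, 32, 64, 64
b = torch.rand((k, n), device="cuda", dtype=torch.bfloat16)
b_int32, scale_and_zeros = _group_quantize_tensor(b, n_bit=4, q_group_size=q_group)
b_int4pack = torch._convert_weight_to_int4pack(b_int32, innerKTiles=2)


class Int4MM(torch.nn.Module):  # hypothetical module, mirrors Model in the test
    def __init__(self, weight, scale_and_zeros, q_group):
        super().__init__()
        self.register_buffer("weight", weight)
        self.register_buffer("scale_and_zeros", scale_and_zeros)
        self.q_group = q_group

    def forward(self, a):
        return torch._weight_int4pack_mm(
            a, self.weight, self.q_group, self.scale_and_zeros
        )


a = torch.rand((m, k), device="cuda", dtype=torch.bfloat16)
ep = torch.export.export(Int4MM(b_int4pack, scale_and_zeros, q_group), (a,))
pt2_path = torch._inductor.aoti_compile_and_package(ep)  # writes a .pt2 package
runner = torch._inductor.aoti_load_package(pt2_path)
print(runner(a).shape)  # expected: torch.Size([32, 64])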
@@ -49,6 +49,7 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__segment_reduce_backward(AtenTe
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__thnn_fused_lstm_cell(AtenTensorHandle input_gates, AtenTensorHandle hidden_gates, AtenTensorHandle cx, AtenTensorHandle* input_bias, AtenTensorHandle* hidden_bias, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__to_sparse(AtenTensorHandle self, int32_t* layout, const int64_t** blocksize, int64_t blocksize_len_, int64_t* dense_dim, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__trilinear(AtenTensorHandle i1, AtenTensorHandle i2, AtenTensorHandle i3, const int64_t* expand1, int64_t expand1_len_, const int64_t* expand2, int64_t expand2_len_, const int64_t* expand3, int64_t expand3_len_, const int64_t* sumdim, int64_t sumdim_len_, int64_t unroll_dim, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__weight_int4pack_mm(AtenTensorHandle self, AtenTensorHandle mat2, int64_t qGroupSize, AtenTensorHandle qScaleAndZeros, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_adaptive_max_pool2d(AtenTensorHandle self, const int64_t* output_size, int64_t output_size_len_, AtenTensorHandle* ret0, AtenTensorHandle* ret1);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_adaptive_max_pool2d_backward(AtenTensorHandle grad_output, AtenTensorHandle self, AtenTensorHandle indices, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_adaptive_max_pool3d(AtenTensorHandle self, const int64_t* output_size, int64_t output_size_len_, AtenTensorHandle* ret0, AtenTensorHandle* ret1);
@@ -51,6 +51,7 @@ inductor_fallback_ops = {
    "aten._thnn_fused_lstm_cell.default",
    "aten._to_sparse.default",
    "aten._trilinear.default",
    "aten._weight_int4pack_mm.default",
    "aten._weight_int8pack_mm.default",
    "aten.adaptive_max_pool2d_backward.default",
    "aten.adaptive_max_pool2d.default",
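A quick way to sanity-check the registration from Python; the torchgen.aoti.fallback_ops module path below is an assumption about where inductor_fallback_ops is defined (this diff does not name the file), so adjust it to your checkout.

# Sketch: verify the op is listed as an AOTI C-shim fallback.
# Assumption: inductor_fallback_ops lives in torchgen/aoti/fallback_ops.py.
from torchgen.aoti.fallback_ops import inductor_fallback_ops

assert "aten._weight_int4pack_mm.default" in inductor_fallback_ops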