Fix meta error in _convert_weight_to_int4pack (#130915)

This PR fixes the meta registration for _convert_weight_to_int4pack. The packed weight w is a [n][k / 2] uint8 tensor (two int4 values per byte), so the meta implementation must compute k as w.size(1) * 2 rather than w.size(1); otherwise the meta output shape disagrees with the shape produced by the real kernel.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/130915
Approved by: https://github.com/jerryzh168
Jiang, Yanbing 2024-07-26 08:36:29 +00:00 committed by PyTorch MergeBot
parent 2bf649f5ae
commit bceb91222c
2 changed files with 22 additions and 1 deletion
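
For illustration only (not part of the PR): a minimal sketch of the shape check the new test performs, using a hypothetical 64 x 64 weight packed as a [64, 32] uint8 tensor. On the meta device the op dispatches to meta__convert_weight_to_int4pack, so before this fix the reported shape was derived from k = 32 instead of k = 64.

import torch

# Hypothetical packed int4 weight: n = 64, k = 64, stored as [n, k // 2] uint8.
w_meta = torch.empty(64, 32, dtype=torch.uint8, device="meta")
packed_meta = torch._convert_weight_to_int4pack(w_meta, innerKTiles=2)
# With k = w.size(1) * 2, this meta shape matches the shape the real kernel
# produces for the same input; the new test below asserts exactly that equality.
print(packed_meta.shape)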

test/test_linalg.py

@@ -6142,6 +6142,27 @@ scipy_lobpcg | {eq_err_scipy:10.2e} | {eq_err_general_scipy:10.2e} | {iters2:
        torch._int_mm(a_int8, b_int8, out=c_int32_result)
        self.assertEqual(c_int32_result.float(), torch.mm(a_float, b_float))

    @unittest.skipIf(IS_WINDOWS, "Skipped on Windows!")
    @unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "cublas runtime error")
    @onlyNativeDeviceTypes
    def test__convert_weight_to_int4pack(self, device):
        # TODO: Fix https://github.com/pytorch/pytorch/issues/131425 and use OpInfo instead
        test_list = [((64, 32), 2), ((64, 48), 2), ((64, 64), 2), ((256, 128), 4), ((256, 128), 8)]
        if self.device_type == 'cuda' and not SM80OrLater:
            self.skipTest("requires SM80 or later")

        if TEST_WITH_ROCM:
            if not CDNA2OrLater():
                self.skipTest("_int4_mm is supported only for CDNA2 or later")

        torch.manual_seed(1)
        for shape, innerKTiles in test_list:
            b = torch.rand(shape, dtype=torch.bfloat16, device=device)
            b_uint8, _ = _group_quantize_tensor(b, n_bit=4, q_group_size=32)
            b_int4pack = torch._convert_weight_to_int4pack(b_uint8, innerKTiles=innerKTiles)
            b_int4pack_meta = torch._convert_weight_to_int4pack(b_uint8.to(device="meta"), innerKTiles=innerKTiles)
            self.assertEqual(b_int4pack.shape, b_int4pack_meta.shape)

    @unittest.skipIf(IS_WINDOWS, "Skipped on Windows!")
    @unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "cublas runtime error")
    @onlyNativeDeviceTypes

torch/_meta_registrations.py

@@ -3262,7 +3262,7 @@ def meta__convert_weight_to_int4pack(w, inner_k_tiles):
         lambda: f"expected w to be uint8, got {w.dtype}",
     )
     n = w.size(0)
-    k = w.size(1)
+    k = w.size(1) * 2  # w is [n][k / 2] uint8
     return w.new_empty(
         (
             n // 8,