diff --git a/test/test_linalg.py b/test/test_linalg.py
index 046e58bc3eb..2647b53a922 100644
--- a/test/test_linalg.py
+++ b/test/test_linalg.py
@@ -6142,6 +6142,27 @@ scipy_lobpcg | {eq_err_scipy:10.2e} | {eq_err_general_scipy:10.2e} | {iters2:
             torch._int_mm(a_int8, b_int8, out=c_int32_result)
             self.assertEqual(c_int32_result.float(), torch.mm(a_float, b_float))
 
+    @unittest.skipIf(IS_WINDOWS, "Skipped on Windows!")
+    @unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "cublas runtime error")
+    @onlyNativeDeviceTypes
+    def test__convert_weight_to_int4pack(self, device):
+        # TODO: Fix https://github.com/pytorch/pytorch/issues/131425 and use OpInfo instead
+        test_list = [((64, 32), 2), ((64, 48), 2), ((64, 64), 2), ((256, 128), 4), ((256, 128), 8)]
+        if self.device_type == 'cuda' and not SM80OrLater:
+            self.skipTest("requires SM80 or later")
+
+        if TEST_WITH_ROCM:
+            if not CDNA2OrLater():
+                self.skipTest("_int4_mm is supported only for CDNA2 or later")
+
+        torch.manual_seed(1)
+        for shape, innerKTiles in test_list:
+            b = torch.rand(shape, dtype=torch.bfloat16, device=device)
+            b_uint8, _ = _group_quantize_tensor(b, n_bit=4, q_group_size=32)
+            b_int4pack = torch._convert_weight_to_int4pack(b_uint8, innerKTiles=innerKTiles)
+            b_int4pack_meta = torch._convert_weight_to_int4pack(b_uint8.to(device="meta"), innerKTiles=innerKTiles)
+            self.assertEqual(b_int4pack.shape, b_int4pack_meta.shape)
+
     @unittest.skipIf(IS_WINDOWS, "Skipped on Windows!")
     @unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "cublas runtime error")
     @onlyNativeDeviceTypes
diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py
index 7e197774e3a..b62d69774e7 100644
--- a/torch/_meta_registrations.py
+++ b/torch/_meta_registrations.py
@@ -3262,7 +3262,7 @@ def meta__convert_weight_to_int4pack(w, inner_k_tiles):
         lambda: f"expected w to be uint8, got {w.dtype}",
     )
     n = w.size(0)
-    k = w.size(1)
+    k = w.size(1) * 2  # w is [n][k / 2] uint8
     return w.new_empty(
         (
             n // 8,
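
Context for the meta-registration change (not part of the patch): the packed weight stores two 4-bit values per uint8 byte, so a logical [n, k] weight arrives as a [n, k // 2] uint8 tensor, and the meta kernel must recover k as w.size(1) * 2 before computing the packed output shape. A minimal sketch of the resulting shape check, assuming the output layout allocated by the meta registration above; the shape values here are illustrative:

import torch

n, k = 256, 128          # logical weight shape
inner_k_tiles = 4

# Packed uint8 weight: two int4 values per byte -> stored as [n, k // 2].
w_uint8 = torch.empty(n, k // 2, dtype=torch.uint8, device="meta")

packed_meta = torch._convert_weight_to_int4pack(w_uint8, innerKTiles=inner_k_tiles)

# With k recovered as w.size(1) * 2, the meta output matches the
# [n // 8, k // (inner_k_tiles * 16), 32, inner_k_tiles // 2] layout
# allocated in meta__convert_weight_to_int4pack.
expected = (n // 8, k // (inner_k_tiles * 16), 32, inner_k_tiles // 2)
assert packed_meta.shape == torch.Size(expected)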