Fix meta error in _convert_weight_to_int4pack (#130915)
This PR fixes a shape error in the meta registration of _convert_weight_to_int4pack: the meta kernel computed the reduction dimension as k = w.size(1), but w is a packed uint8 tensor holding two int4 values per byte, so the logical k is w.size(1) * 2.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/130915
Approved by: https://github.com/jerryzh168
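For context, a minimal repro sketch of the mismatch (sizes borrowed from the new test below; the randint input and default-device execution are illustrative assumptions, not part of the change):

import torch

# Packed int4 weight: two 4-bit values per byte, stored as [n][k // 2] uint8.
b_uint8 = torch.randint(0, 256, (64, 32), dtype=torch.uint8)

# The eager kernel and the meta kernel should report the same output shape.
eager = torch._convert_weight_to_int4pack(b_uint8, innerKTiles=2)
meta = torch._convert_weight_to_int4pack(b_uint8.to("meta"), innerKTiles=2)
assert eager.shape == meta.shape  # failed before this fix: meta halved k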
This commit is contained in:
parent 2bf649f5ae
commit bceb91222c
@@ -6142,6 +6142,27 @@ scipy_lobpcg | {eq_err_scipy:10.2e} | {eq_err_general_scipy:10.2e} | {iters2:
         torch._int_mm(a_int8, b_int8, out=c_int32_result)
         self.assertEqual(c_int32_result.float(), torch.mm(a_float, b_float))
 
+    @unittest.skipIf(IS_WINDOWS, "Skipped on Windows!")
+    @unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "cublas runtime error")
+    @onlyNativeDeviceTypes
+    def test__convert_weight_to_int4pack(self, device):
+        # TODO: Fix https://github.com/pytorch/pytorch/issues/131425 and use OpInfo instead
+        test_list = [((64, 32), 2), ((64, 48), 2), ((64, 64), 2), ((256, 128), 4), ((256, 128), 8)]
+        if self.device_type == 'cuda' and not SM80OrLater:
+            self.skipTest("requires SM80 or later")
+
+        if TEST_WITH_ROCM:
+            if not CDNA2OrLater():
+                self.skipTest("_int4_mm is supported only for CDNA2 or later")
+
+        torch.manual_seed(1)
+        for shape, innerKTiles in test_list:
+            b = torch.rand(shape, dtype=torch.bfloat16, device=device)
+            b_uint8, _ = _group_quantize_tensor(b, n_bit=4, q_group_size=32)
+            b_int4pack = torch._convert_weight_to_int4pack(b_uint8, innerKTiles=innerKTiles)
+            b_int4pack_meta = torch._convert_weight_to_int4pack(b_uint8.to(device="meta"), innerKTiles=innerKTiles)
+            self.assertEqual(b_int4pack.shape, b_int4pack_meta.shape)
+
     @unittest.skipIf(IS_WINDOWS, "Skipped on Windows!")
     @unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "cublas runtime error")
     @onlyNativeDeviceTypes
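The shape agreement the test asserts is what keeps tracing correct: under fake tensors (for example during torch.compile), no real kernel runs and output shapes come from the meta registration alone. An illustrative sketch, assuming FakeTensorMode from torch._subclasses.fake_tensor (not part of this diff):

import torch
from torch._subclasses.fake_tensor import FakeTensorMode

# Inside FakeTensorMode every op is shape-propagated via its meta kernel,
# so a wrong formula there silently corrupts traced shapes.
with FakeTensorMode():
    w = torch.empty(64, 32, dtype=torch.uint8)  # fake tensor, [n][k // 2]
    packed = torch._convert_weight_to_int4pack(w, innerKTiles=2)
    print(packed.shape)  # produced by meta__convert_weight_to_int4pack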
@@ -3262,7 +3262,7 @@ def meta__convert_weight_to_int4pack(w, inner_k_tiles):
         lambda: f"expected w to be uint8, got {w.dtype}",
     )
     n = w.size(0)
-    k = w.size(1)
+    k = w.size(1) * 2  # w is [n][k / 2] uint8
     return w.new_empty(
         (
             n // 8,
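The one-line fix encodes the packing noted in the added comment: w stores two int4 values per uint8 byte, so the logical reduction dimension is twice the stored byte width. A quick sketch of the arithmetic with illustrative sizes (the (64, 32) shape is borrowed from the test above):

import torch

w = torch.empty(64, 32, dtype=torch.uint8, device="meta")  # [n][k // 2]
n = w.size(0)       # 64
k = w.size(1) * 2   # 64: two int4 values per byte doubles the byte width
# Before the fix, k was taken as w.size(1) (32 here), which shrank every
# k-derived dimension of the packed output returned by w.new_empty(...).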