[DCP][Quantization] Fix the issue when scale vector is in a different SafeTensors file (#162214)
Summary: The current dequantization implementation assumes that the weight and scale tensors are in the same SafeTensors file. This diff fixes the issue to support the case where they are in different files.

Test Plan: buck test fbcode//caffe2/test/distributed/checkpoint:test_quantized_hf_storage
Buck UI: https://www.internalfb.com/buck2/532bf151-bb40-41fd-b080-ff898675afe2
Test UI: https://www.internalfb.com/intern/testinfra/testrun/15199648851011082

Rollback Plan:

Differential Revision: D81718598
Pull Request resolved: https://github.com/pytorch/pytorch/pull/162214
Approved by: https://github.com/wwwjn
commit 01ab325cc2 (parent 79fcd5247a)
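For orientation, the mechanism the fix relies on can be sketched standalone (a minimal sketch, not the PR's code; the checkpoint directory, file names, and FQNs below are illustrative). Sharded Hugging Face checkpoints ship a weight_map that maps each tensor FQN to the shard file storing it, so the scale tensor's shard can be resolved independently of the weight's shard:

    from pathlib import Path

    from safetensors import safe_open

    # Illustrative weight_map, as stored in model.safetensors.index.json:
    # each tensor FQN maps to the shard file that contains it.
    weight_map = {
        "model.layers.0.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
        "model.layers.0.self_attn.k_proj.weight_scale_inv": "model-00002-of-00002.safetensors",
    }

    def load_scale(checkpoint_dir: str, weight_fqn: str):
        # Derive the scale FQN the same way the reader does, then look up
        # which shard holds it -- possibly not the shard holding the weight.
        scale_fqn = weight_fqn.replace(".weight", ".weight_scale_inv")
        scale_file = Path(checkpoint_dir) / weight_map[scale_fqn]
        with safe_open(scale_file, framework="pt", device="cpu") as f:
            return f.get_tensor(scale_fqn)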
@@ -1,7 +1,7 @@
 # Owner(s): ["oncall: distributed checkpointing"]
 
 import tempfile
-from unittest.mock import MagicMock
+from unittest.mock import MagicMock, patch
 
 import torch
 from torch.distributed.checkpoint.metadata import MetadataIndex
@@ -23,38 +23,70 @@ class TestQuantizedHfStorage(TestCase):
         self.temp_dir.cleanup()
 
     def test_dequantization(self):
-        """Test that quantized tensors are properly dequantized during read operations."""
+        """Test quantized tensors with weights and scales in both same and different files."""
         reader = QuantizedHuggingFaceStorageReader(self.path, thread_count=1)
 
-        # Test data
-        quantized_tensor = torch.ones(4, 4, dtype=torch.float32)
-        scale_inv = torch.tensor([[2.0]], dtype=torch.float32)
+        # Test data for two different weights
+        quantized_tensor1 = torch.ones(4, 4, dtype=torch.float32)
+        quantized_tensor2 = (
+            torch.ones(4, 4, dtype=torch.float32) * 3.0
+        )  # Different values
+        scale_inv1 = torch.tensor([[2.0]], dtype=torch.float32)
+        scale_inv2 = torch.tensor([[0.5]], dtype=torch.float32)  # Different scale
 
-        # Mock the safetensors file for reading data
-        mock_file = MagicMock()
+        # Define weight and scale tensor names
+        weight1_fqn = "model.layers.0.self_attn.q_proj.weight"  # Scale in same file
+        scale1_fqn = "model.layers.0.self_attn.q_proj.weight_scale_inv"
+        weight2_fqn = (
+            "model.layers.0.self_attn.k_proj.weight"  # Scale in different file
+        )
+        scale2_fqn = "model.layers.0.self_attn.k_proj.weight_scale_inv"
 
-        # Mock get_slice to return a tensor that can be sliced
-        def mock_get_slice(tensor_name):
-            mock_tensor = MagicMock()
-            mock_tensor.__getitem__ = lambda self, slices: quantized_tensor
-            return mock_tensor
-
-        mock_file.get_slice = mock_get_slice
-        mock_file.get_tensor.return_value = scale_inv
+        file1_name = "model-00001-of-00002.safetensors"
+        file2_name = "model-00002-of-00002.safetensors"
 
+        # Setup weight-scale mapping and file locations
         reader._weight_scale_mapping = {
-            "model.layers.0.self_attn.kv_b_proj.weight": "model.layers.0.self_attn.kv_b_proj.weight_scale_inv",
+            weight1_fqn: scale1_fqn,
+            weight2_fqn: scale2_fqn,
         }
+        reader._weight_map = {
+            weight1_fqn: file1_name,  # Weight in file 1
+            scale1_fqn: file1_name,  # Scale also in file 1 (same file scenario)
+            weight2_fqn: file1_name,  # Weight in file 1
+            scale2_fqn: file2_name,  # Scale in file 2 (different file scenario)
+        }
 
-        # Create a read request for quantized tensor
-        read_item = ReadItem(
+        # Mock the main safetensors file (file1)
+        mock_file1 = MagicMock()
+
+        # Mock get_slice to return different tensors based on tensor name
+        def mock_get_slice(tensor_name):
+            mock_tensor = MagicMock()
+            if tensor_name == weight1_fqn:
+                mock_tensor.__getitem__ = lambda _, __: quantized_tensor1
+            elif tensor_name == weight2_fqn:
+                mock_tensor.__getitem__ = lambda _, __: quantized_tensor2
+            return mock_tensor
+
+        mock_file1.get_slice = mock_get_slice
+
+        # Mock get_tensor for same-file scale (scale1)
+        mock_file1.get_tensor.return_value = scale_inv1
+
+        # Mock the cross-file safetensors file (file2) for scale2
+        mock_file2 = MagicMock()
+        mock_file2.get_tensor.return_value = scale_inv2
+
+        # Test 1: Same-file scenario (weight1 + scale1 both in file1)
+        read_item1 = ReadItem(
             type=LoadItemType.TENSOR,
             storage_index=MetadataIndex(
-                fqn="model.layers.0.self_attn.kv_b_proj.weight",
+                fqn=weight1_fqn,
                 offset=torch.Size([0, 0]),
             ),
             dest_index=MetadataIndex(
-                fqn="model.layers.0.self_attn.kv_b_proj.weight",
+                fqn=weight1_fqn,
                 offset=torch.Size([0, 0]),
             ),
             storage_offsets=[0, 0],
@@ -62,22 +94,73 @@ class TestQuantizedHfStorage(TestCase):
             lengths=[4, 4],
         )
 
-        # Mock planner
-        target_tensor = torch.zeros(4, 4, dtype=torch.float32)
-        mock_planner = MagicMock()
-        mock_planner.resolve_tensor.return_value = target_tensor
+        target_tensor1 = torch.zeros(4, 4, dtype=torch.float32)
+        mock_planner1 = MagicMock()
+        mock_planner1.resolve_tensor.return_value = target_tensor1
 
-        # Test the _process_read_request method
-        reader._process_read_request(mock_file, read_item, mock_planner)
+        # Process first weight (same file scenario)
+        reader._process_read_request(mock_file1, read_item1, mock_planner1)
 
-        # Verify the tensor was dequantized (ones * 2.0 = twos)
-        expected_result = torch.ones(4, 4, dtype=torch.float32) * 2.0
-        mock_planner.commit_tensor.assert_called_once()
+        # Verify first tensor was dequantized (ones * 2.0 = twos)
+        expected_result1 = torch.ones(4, 4, dtype=torch.float32) * 2.0
+        mock_planner1.commit_tensor.assert_called_once()
 
-        # Check that target_tensor was updated correctly
-        args, _ = mock_planner.commit_tensor.call_args
-        committed_tensor = args[1]  # second argument is the tensor
-        torch.testing.assert_close(committed_tensor, expected_result)
+        # Check that target_tensor1 was updated correctly
+        args1, _ = mock_planner1.commit_tensor.call_args
+        committed_tensor1 = args1[1]
+        torch.testing.assert_close(committed_tensor1, expected_result1)
+
+        # Test 2: Cross-file scenario (weight2 in file1, scale2 in file2)
+        read_item2 = ReadItem(
+            type=LoadItemType.TENSOR,
+            storage_index=MetadataIndex(
+                fqn=weight2_fqn,
+                offset=torch.Size([0, 0]),
+            ),
+            dest_index=MetadataIndex(
+                fqn=weight2_fqn,
+                offset=torch.Size([0, 0]),
+            ),
+            storage_offsets=[0, 0],
+            dest_offsets=[0, 0],
+            lengths=[4, 4],
+        )
+
+        target_tensor2 = torch.zeros(4, 4, dtype=torch.float32)
+        mock_planner2 = MagicMock()
+        mock_planner2.resolve_tensor.return_value = target_tensor2
+
+        # Mock the entire safetensors module since it may not be available in test environment
+        mock_safetensors = MagicMock()
+        mock_safe_open = MagicMock()
+        mock_safetensors.safe_open = mock_safe_open
+
+        # Set up the mock to return a context manager that yields mock_file2
+        mock_safe_open.return_value.__enter__.return_value = mock_file2
+        mock_safe_open.return_value.__exit__.return_value = False
+
+        # Mock the module import and safe_open function
+        with patch.dict("sys.modules", {"safetensors": mock_safetensors}):
+            # Process second weight (cross-file scenario)
+            reader._process_read_request(mock_file1, read_item2, mock_planner2)
+
+        # Verify safe_open was called with the correct file path
+        expected_path = f"{self.path}/{file2_name}"
+        mock_safe_open.assert_called_once()
+        call_args = mock_safe_open.call_args[0]
+        self.assertEqual(str(call_args[0]), expected_path)
+
+        # Verify the scale tensor was loaded from the correct file
+        mock_file2.get_tensor.assert_called_once_with(scale2_fqn)
+
+        # Verify second tensor was dequantized (3.0 * 0.5 = 1.5)
+        expected_result2 = torch.ones(4, 4, dtype=torch.float32) * 3.0 * 0.5  # 1.5
+        mock_planner2.commit_tensor.assert_called_once()
+
+        # Check that target_tensor2 was updated correctly
+        args2, _ = mock_planner2.commit_tensor.call_args
+        committed_tensor2 = args2[1]
+        torch.testing.assert_close(committed_tensor2, expected_result2)
 
 
 if __name__ == "__main__":
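The _dequantize_tensor helper the test exercises lies outside this diff's hunks; as a rough sketch of what block-wise dequantization of this shape looks like (an assumption inferred from the block_size and scale_inv conventions above, not the reader's actual implementation):

    import torch

    def dequantize_blockwise(
        quantized: torch.Tensor,   # (rows, cols) quantized weight values
        scale_inv: torch.Tensor,   # (rows // block_size, cols // block_size)
        block_size: int,
        target_dtype: torch.dtype = torch.float32,
    ) -> torch.Tensor:
        # Expand each scale entry across its block_size x block_size tile,
        # then scale the quantized values elementwise.
        scales = scale_inv.repeat_interleave(block_size, dim=0)
        scales = scales.repeat_interleave(block_size, dim=1)
        return quantized.to(target_dtype) * scales.to(target_dtype)

With the test's numbers (a 4x4 weight of ones, block_size 4, scale_inv [[2.0]]) this yields the expected tensor of 2.0s; the second case (a weight of 3.0s, scale_inv [[0.5]]) yields 1.5s, matching expected_result2.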
@@ -48,7 +48,8 @@ class QuantizedHuggingFaceStorageReader(HuggingFaceStorageReader):
         self.target_dtype: torch.dtype = target_dtype
         self.block_size: int = block_size
         self._weight_scale_mapping: dict[str, str] = {}
-        self._scale_tensor_cache: dict[str, torch.Tensor] = {}
+        # Track which file contains each tensor
+        self._weight_map: dict[str, str] = {}
 
     def read_metadata(self) -> Any:
         self._load_quantization_metadata()
@@ -67,6 +68,9 @@ class QuantizedHuggingFaceStorageReader(HuggingFaceStorageReader):
 
     def _build_weight_scale_mapping(self, weight_map: dict[str, str]):
         """Analyze and build weight-scale tensor pairs from weight mapping."""
+        # Store the complete weight map for file location lookups
+        self._weight_map = weight_map
+
         for tensor_name in weight_map.keys():
             if tensor_name.endswith(".weight_scale_inv"):
                 weight_name = tensor_name.replace(".weight_scale_inv", ".weight")
@@ -206,14 +210,26 @@ class QuantizedHuggingFaceStorageReader(HuggingFaceStorageReader):
         quantized_tensor = safetensor_file.get_slice(tensor_fqn)[weight_slices]
 
         # Load the corresponding scale inverse tensor
-        # For scale tensors, we typically need the full tensor for proper block alignment
-        if scale_fqn not in self._scale_tensor_cache:
-            scale_inv = safetensor_file.get_tensor(
-                scale_fqn
-            )  # Load full scale tensor
-            self._scale_tensor_cache[scale_fqn] = scale_inv
+        # Use weight_map to find the correct file for the scale tensor
+        scale_file_name = self._weight_map.get(scale_fqn)
+        if scale_file_name is None:
+            raise ValueError(f"Scale tensor {scale_fqn} not found in weight_map")
+
+        # Check if scale tensor is in the same file as the weight tensor
+        weight_file_name = self._weight_map.get(tensor_fqn)
+
+        if scale_file_name == weight_file_name:
+            # Scale tensor is in the same file, use current handle
+            scale_inv = safetensor_file.get_tensor(scale_fqn)
         else:
-            scale_inv = self._scale_tensor_cache[scale_fqn]
+            # Scale tensor is in a different file, need to open it
+            from safetensors import safe_open  # type: ignore[import]
+
+            scale_file_path = Path(self.path) / scale_file_name
+            with safe_open(
+                scale_file_path, framework="pt", device="cpu"
+            ) as scale_file:
+                scale_inv = scale_file.get_tensor(scale_fqn)
 
         # Perform dequantization
         dequantized_tensor = self._dequantize_tensor(
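For completeness, a hedged end-to-end usage sketch: the constructor call below mirrors the test (path plus thread_count), while the import path and the state_dict contents are assumptions for illustration; dcp.load() is the standard DCP entry point that drives _process_read_request through the storage reader:

    import torch
    import torch.distributed.checkpoint as dcp

    # Assumed import location for the reader shown in this diff.
    from torch.distributed.checkpoint.quantized_hf_storage import (
        QuantizedHuggingFaceStorageReader,
    )

    # Destination tensors to fill; FQNs and shapes must match the checkpoint.
    state_dict = {
        "model.layers.0.self_attn.q_proj.weight": torch.empty(4, 4),
    }

    reader = QuantizedHuggingFaceStorageReader(
        "/path/to/hf_checkpoint",  # directory of *.safetensors shards + index
        thread_count=1,
    )
    dcp.load(state_dict=state_dict, storage_reader=reader)
    # Each weight is dequantized on read; its scale is fetched from whichever
    # shard the weight_map says holds it, same file or not.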