[DCP][Quantization] Fix the issue when scale vector is in a different SafeTensors file (#162214)

Summary: The current dequantization implementation assumes that the weight and scale tensors are stored in the same SafeTensors file. This diff fixes the issue by supporting the case where the two tensors live in different files.
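In short, the reader now resolves the scale tensor's location through the checkpoint's weight map (the fqn-to-file mapping from the safetensors index) and opens a second file handle only when the scale lives elsewhere. A minimal sketch of that lookup follows; the function name and arguments are illustrative stand-ins, not the actual method signature:

```python
from pathlib import Path

from safetensors import safe_open


def load_scale_inv(weight_map, checkpoint_dir, weight_file, weight_fqn, scale_fqn):
    """Fetch the scale-inverse tensor, opening a second file only if needed.

    weight_map: fqn -> safetensors file name (e.g. from model.safetensors.index.json)
    weight_file: the already-open handle for the file holding weight_fqn
    """
    scale_file_name = weight_map.get(scale_fqn)
    if scale_file_name is None:
        raise ValueError(f"Scale tensor {scale_fqn} not found in weight_map")
    if scale_file_name == weight_map.get(weight_fqn):
        # Same file: reuse the handle that is already open for the weight.
        return weight_file.get_tensor(scale_fqn)
    # Different file: open it just long enough to read the scale tensor.
    with safe_open(
        Path(checkpoint_dir) / scale_file_name, framework="pt", device="cpu"
    ) as f:
        return f.get_tensor(scale_fqn)
```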

Test Plan:
buck test fbcode//caffe2/test/distributed/checkpoint\:test_quantized_hf_storage

Buck UI: https://www.internalfb.com/buck2/532bf151-bb40-41fd-b080-ff898675afe2
Test UI: https://www.internalfb.com/intern/testinfra/testrun/15199648851011082

Rollback Plan:

Differential Revision: D81718598

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162214
Approved by: https://github.com/wwwjn
Author: Saurabh Mishra
Date: 2025-09-05 22:43:58 +00:00
Committed by: PyTorch MergeBot
Parent: 79fcd5247a
Commit: 01ab325cc2

2 changed files with 140 additions and 41 deletions


@@ -1,7 +1,7 @@
 # Owner(s): ["oncall: distributed checkpointing"]

 import tempfile
-from unittest.mock import MagicMock
+from unittest.mock import MagicMock, patch

 import torch
 from torch.distributed.checkpoint.metadata import MetadataIndex
@@ -23,38 +23,70 @@ class TestQuantizedHfStorage(TestCase):
         self.temp_dir.cleanup()

     def test_dequantization(self):
-        """Test that quantized tensors are properly dequantized during read operations."""
+        """Test quantized tensors with weights and scales in both same and different files."""
         reader = QuantizedHuggingFaceStorageReader(self.path, thread_count=1)

-        # Test data
-        quantized_tensor = torch.ones(4, 4, dtype=torch.float32)
-        scale_inv = torch.tensor([[2.0]], dtype=torch.float32)
+        # Test data for two different weights
+        quantized_tensor1 = torch.ones(4, 4, dtype=torch.float32)
+        quantized_tensor2 = (
+            torch.ones(4, 4, dtype=torch.float32) * 3.0
+        )  # Different values
+        scale_inv1 = torch.tensor([[2.0]], dtype=torch.float32)
+        scale_inv2 = torch.tensor([[0.5]], dtype=torch.float32)  # Different scale

-        # Mock the safetensors file for reading data
-        mock_file = MagicMock()
+        # Define weight and scale tensor names
+        weight1_fqn = "model.layers.0.self_attn.q_proj.weight"  # Scale in same file
+        scale1_fqn = "model.layers.0.self_attn.q_proj.weight_scale_inv"
+        weight2_fqn = (
+            "model.layers.0.self_attn.k_proj.weight"  # Scale in different file
+        )
+        scale2_fqn = "model.layers.0.self_attn.k_proj.weight_scale_inv"

-        # Mock get_slice to return a tensor that can be sliced
-        def mock_get_slice(tensor_name):
-            mock_tensor = MagicMock()
-            mock_tensor.__getitem__ = lambda self, slices: quantized_tensor
-            return mock_tensor
-
-        mock_file.get_slice = mock_get_slice
-        mock_file.get_tensor.return_value = scale_inv
+        file1_name = "model-00001-of-00002.safetensors"
+        file2_name = "model-00002-of-00002.safetensors"

+        # Setup weight-scale mapping and file locations
         reader._weight_scale_mapping = {
-            "model.layers.0.self_attn.kv_b_proj.weight": "model.layers.0.self_attn.kv_b_proj.weight_scale_inv",
+            weight1_fqn: scale1_fqn,
+            weight2_fqn: scale2_fqn,
         }
+        reader._weight_map = {
+            weight1_fqn: file1_name,  # Weight in file 1
+            scale1_fqn: file1_name,  # Scale also in file 1 (same file scenario)
+            weight2_fqn: file1_name,  # Weight in file 1
+            scale2_fqn: file2_name,  # Scale in file 2 (different file scenario)
+        }

-        # Create a read request for quantized tensor
-        read_item = ReadItem(
+        # Mock the main safetensors file (file1)
+        mock_file1 = MagicMock()
+
+        # Mock get_slice to return different tensors based on tensor name
+        def mock_get_slice(tensor_name):
+            mock_tensor = MagicMock()
+            if tensor_name == weight1_fqn:
+                mock_tensor.__getitem__ = lambda _, __: quantized_tensor1
+            elif tensor_name == weight2_fqn:
+                mock_tensor.__getitem__ = lambda _, __: quantized_tensor2
+            return mock_tensor
+
+        mock_file1.get_slice = mock_get_slice
+        # Mock get_tensor for same-file scale (scale1)
+        mock_file1.get_tensor.return_value = scale_inv1
+
+        # Mock the cross-file safetensors file (file2) for scale2
+        mock_file2 = MagicMock()
+        mock_file2.get_tensor.return_value = scale_inv2
+
+        # Test 1: Same-file scenario (weight1 + scale1 both in file1)
+        read_item1 = ReadItem(
             type=LoadItemType.TENSOR,
             storage_index=MetadataIndex(
-                fqn="model.layers.0.self_attn.kv_b_proj.weight",
+                fqn=weight1_fqn,
                 offset=torch.Size([0, 0]),
             ),
             dest_index=MetadataIndex(
-                fqn="model.layers.0.self_attn.kv_b_proj.weight",
+                fqn=weight1_fqn,
                 offset=torch.Size([0, 0]),
             ),
             storage_offsets=[0, 0],
@@ -62,22 +94,73 @@ class TestQuantizedHfStorage(TestCase):
             lengths=[4, 4],
         )

         # Mock planner
-        target_tensor = torch.zeros(4, 4, dtype=torch.float32)
-        mock_planner = MagicMock()
-        mock_planner.resolve_tensor.return_value = target_tensor
+        target_tensor1 = torch.zeros(4, 4, dtype=torch.float32)
+        mock_planner1 = MagicMock()
+        mock_planner1.resolve_tensor.return_value = target_tensor1

-        # Test the _process_read_request method
-        reader._process_read_request(mock_file, read_item, mock_planner)
+        # Process first weight (same file scenario)
+        reader._process_read_request(mock_file1, read_item1, mock_planner1)

-        # Verify the tensor was dequantized (ones * 2.0 = twos)
-        expected_result = torch.ones(4, 4, dtype=torch.float32) * 2.0
-        mock_planner.commit_tensor.assert_called_once()
+        # Verify first tensor was dequantized (ones * 2.0 = twos)
+        expected_result1 = torch.ones(4, 4, dtype=torch.float32) * 2.0
+        mock_planner1.commit_tensor.assert_called_once()

-        # Check that target_tensor was updated correctly
-        args, _ = mock_planner.commit_tensor.call_args
-        committed_tensor = args[1]  # second argument is the tensor
-        torch.testing.assert_close(committed_tensor, expected_result)
+        # Check that target_tensor1 was updated correctly
+        args1, _ = mock_planner1.commit_tensor.call_args
+        committed_tensor1 = args1[1]
+        torch.testing.assert_close(committed_tensor1, expected_result1)
+
+        # Test 2: Cross-file scenario (weight2 in file1, scale2 in file2)
+        read_item2 = ReadItem(
+            type=LoadItemType.TENSOR,
+            storage_index=MetadataIndex(
+                fqn=weight2_fqn,
+                offset=torch.Size([0, 0]),
+            ),
+            dest_index=MetadataIndex(
+                fqn=weight2_fqn,
+                offset=torch.Size([0, 0]),
+            ),
+            storage_offsets=[0, 0],
+            dest_offsets=[0, 0],
+            lengths=[4, 4],
+        )
+
+        target_tensor2 = torch.zeros(4, 4, dtype=torch.float32)
+        mock_planner2 = MagicMock()
+        mock_planner2.resolve_tensor.return_value = target_tensor2
+
+        # Mock the entire safetensors module since it may not be available in test environment
+        mock_safetensors = MagicMock()
+        mock_safe_open = MagicMock()
+        mock_safetensors.safe_open = mock_safe_open
+        # Set up the mock to return a context manager that yields mock_file2
+        mock_safe_open.return_value.__enter__.return_value = mock_file2
+        mock_safe_open.return_value.__exit__.return_value = False
+
+        # Mock the module import and safe_open function
+        with patch.dict("sys.modules", {"safetensors": mock_safetensors}):
+            # Process second weight (cross-file scenario)
+            reader._process_read_request(mock_file1, read_item2, mock_planner2)
+
+        # Verify safe_open was called with the correct file path
+        expected_path = f"{self.path}/{file2_name}"
+        mock_safe_open.assert_called_once()
+        call_args = mock_safe_open.call_args[0]
+        self.assertEqual(str(call_args[0]), expected_path)
+
+        # Verify the scale tensor was loaded from the correct file
+        mock_file2.get_tensor.assert_called_once_with(scale2_fqn)
+
+        # Verify second tensor was dequantized (3.0 * 0.5 = 1.5)
+        expected_result2 = torch.ones(4, 4, dtype=torch.float32) * 3.0 * 0.5  # 1.5
+        mock_planner2.commit_tensor.assert_called_once()
+
+        # Check that target_tensor2 was updated correctly
+        args2, _ = mock_planner2.commit_tensor.call_args
+        committed_tensor2 = args2[1]
+        torch.testing.assert_close(committed_tensor2, expected_result2)


 if __name__ == "__main__":


@@ -48,7 +48,8 @@ class QuantizedHuggingFaceStorageReader(HuggingFaceStorageReader):
         self.target_dtype: torch.dtype = target_dtype
         self.block_size: int = block_size
         self._weight_scale_mapping: dict[str, str] = {}
-        self._scale_tensor_cache: dict[str, torch.Tensor] = {}
+        # Track which file contains each tensor
+        self._weight_map: dict[str, str] = {}

     def read_metadata(self) -> Any:
         self._load_quantization_metadata()
@@ -67,6 +68,9 @@ class QuantizedHuggingFaceStorageReader(HuggingFaceStorageReader):
     def _build_weight_scale_mapping(self, weight_map: dict[str, str]):
         """Analyze and build weight-scale tensor pairs from weight mapping."""
+        # Store the complete weight map for file location lookups
+        self._weight_map = weight_map
+
         for tensor_name in weight_map.keys():
             if tensor_name.endswith(".weight_scale_inv"):
                 weight_name = tensor_name.replace(".weight_scale_inv", ".weight")
@@ -206,14 +210,26 @@ class QuantizedHuggingFaceStorageReader(HuggingFaceStorageReader):
         quantized_tensor = safetensor_file.get_slice(tensor_fqn)[weight_slices]

         # Load the corresponding scale inverse tensor
         # For scale tensors, we typically need the full tensor for proper block alignment
-        if scale_fqn not in self._scale_tensor_cache:
-            scale_inv = safetensor_file.get_tensor(
-                scale_fqn
-            )  # Load full scale tensor
-            self._scale_tensor_cache[scale_fqn] = scale_inv
+        # Use weight_map to find the correct file for the scale tensor
+        scale_file_name = self._weight_map.get(scale_fqn)
+        if scale_file_name is None:
+            raise ValueError(f"Scale tensor {scale_fqn} not found in weight_map")
+
+        # Check if scale tensor is in the same file as the weight tensor
+        weight_file_name = self._weight_map.get(tensor_fqn)
+        if scale_file_name == weight_file_name:
+            # Scale tensor is in the same file, use current handle
+            scale_inv = safetensor_file.get_tensor(scale_fqn)
         else:
-            scale_inv = self._scale_tensor_cache[scale_fqn]
+            # Scale tensor is in a different file, need to open it
+            from safetensors import safe_open  # type: ignore[import]
+
+            scale_file_path = Path(self.path) / scale_file_name
+            with safe_open(
+                scale_file_path, framework="pt", device="cpu"
+            ) as scale_file:
+                scale_inv = scale_file.get_tensor(scale_fqn)

         # Perform dequantization
         dequantized_tensor = self._dequantize_tensor(
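For reference on the values asserted in the test above: dequantization here multiplies each `block_size x block_size` block of the quantized weight by the matching scalar in the scale-inverse tensor, so ones scaled by 2.0 become twos and threes scaled by 0.5 become 1.5. A rough sketch of that block-wise math follows; it is an assumption about `_dequantize_tensor`'s behavior inferred from the test expectations, not the exact implementation:

```python
import torch


def dequantize_blockwise(
    q: torch.Tensor, scale_inv: torch.Tensor, block_size: int
) -> torch.Tensor:
    """Multiply each (block_size x block_size) block by its scale entry (sketch)."""
    out = q.to(torch.float32).clone()
    for i in range(scale_inv.shape[0]):
        for j in range(scale_inv.shape[1]):
            rows = slice(i * block_size, (i + 1) * block_size)
            cols = slice(j * block_size, (j + 1) * block_size)
            out[rows, cols] *= scale_inv[i, j]
    return out


# Matches the test's expectations: a 4x4 weight of ones with scale_inv [[2.0]]
# dequantizes to all 2.0 when a single block covers the whole tensor.
torch.testing.assert_close(
    dequantize_blockwise(torch.ones(4, 4), torch.tensor([[2.0]]), 4),
    torch.full((4, 4), 2.0),
)
```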