pytorch/torch/distributed/checkpoint
Saurabh Mishra 134dfbeaef [DCP] DTensor slice dequantization with proper block alignment (#163532)
Summary:
When loading quantized tensors with DTensor slicing, the dequantization process was producing numerically incorrect results due to improper block-to-slice coordinate mapping. The previous implementation calculated block boundaries relative to the sliced tensor dimensions instead of the original full tensor dimensions, causing scale factors to be applied to wrong tensor regions.

This fix addresses the issue by:

1. **Proper coordinate mapping**: Added `_get_slice_to_block_mapping()` to correctly map tensor slices to quantization blocks using global coordinates from the full tensor shape.

2. **Block-aligned dequantization**: Updated `_dequantize_tensor()` to use proper block intersection logic, ensuring scale factors are applied to the correct portions of sliced tensors.

The fix ensures that when DTensor requests a slice of a quantized tensor, the dequantization correctly identifies which quantization blocks intersect with the requested slice and applies the appropriate scale factors to the right tensor regions.
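
As a rough illustration of the coordinate mapping described above, the sketch below finds the quantization blocks that intersect a 2-D slice, using the slice's global offsets in the full tensor. It is not the code from `quantized_hf_storage.py`: the function name `slice_to_block_mapping`, its signature, and the block size of 128 are assumptions made for this example.

```python
from typing import List, Tuple

# (block_row, block_col, rows within the slice, cols within the slice)
BlockHit = Tuple[int, int, slice, slice]


def slice_to_block_mapping(
    slice_offsets: Tuple[int, int],
    slice_shape: Tuple[int, int],
    block_size: int,
) -> List[BlockHit]:
    """Hypothetical helper: list the quantization blocks a slice intersects.

    Block indices are computed in the coordinates of the FULL tensor, and the
    returned slice objects index into the LOCAL (sliced) tensor.
    """
    row_start, col_start = slice_offsets
    n_rows, n_cols = slice_shape
    if n_rows == 0 or n_cols == 0:
        return []
    row_end, col_end = row_start + n_rows, col_start + n_cols

    hits: List[BlockHit] = []
    for block_row in range(row_start // block_size, (row_end - 1) // block_size + 1):
        # Intersect the block's global row range with the slice's row range.
        r0 = max(row_start, block_row * block_size)
        r1 = min(row_end, (block_row + 1) * block_size)
        for block_col in range(col_start // block_size, (col_end - 1) // block_size + 1):
            c0 = max(col_start, block_col * block_size)
            c1 = min(col_end, (block_col + 1) * block_size)
            hits.append(
                (
                    block_row,
                    block_col,
                    slice(r0 - row_start, r1 - row_start),  # shift to local rows
                    slice(c0 - col_start, c1 - col_start),  # shift to local cols
                )
            )
    return hits


# A 56 x 128 slice that starts at global row 100 straddles block rows 0 and 1.
aligned = slice_to_block_mapping((100, 0), (56, 128), block_size=128)
# Pretending the slice starts at row 0 (the failure mode) sees only block row 0.
misaligned = slice_to_block_mapping((0, 0), (56, 128), block_size=128)
```

With the global offset, the first 28 local rows take the scale of block row 0 and the remaining 28 take the scale of block row 1. Computing boundaries relative to the slice alone would assign all 56 rows to block row 0, which matches the incorrect behavior described in the summary.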

Test Plan:
Tested with DTensor configurations where quantized tensors are sliced across different dimensions. Verified that:
1. Dequantized tensor values are numerically correct
2. Block boundaries are properly calculated relative to full tensor shape
3. Scale factors are applied to correct tensor regions
4. The tensor shape map is built efficiently using only metadata
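
A minimal sketch of the kind of check the first three items imply: dequantizing a slice directly should give the same values as dequantizing the full tensor and slicing afterwards. The helpers `dequantize_full` and `dequantize_slice`, the nested-list tensors, and the tiny block size are all assumptions for illustration and do not come from the PR's test code.

```python
from typing import List

Matrix = List[List[float]]


def dequantize_full(q: Matrix, scales: Matrix, block_size: int) -> Matrix:
    """Multiply each element by the scale of the block that contains it."""
    return [
        [q[r][c] * scales[r // block_size][c // block_size] for c in range(len(q[0]))]
        for r in range(len(q))
    ]


def dequantize_slice(
    q_slice: Matrix, scales: Matrix, block_size: int, row_offset: int, col_offset: int
) -> Matrix:
    """Dequantize a slice, looking up scales by GLOBAL (full-tensor) position."""
    return [
        [
            q_slice[r][c]
            * scales[(row_offset + r) // block_size][(col_offset + c) // block_size]
            for c in range(len(q_slice[0]))
        ]
        for r in range(len(q_slice))
    ]


# Hypothetical 6 x 6 quantized tensor with 4 x 4 blocks, so a 2 x 2 scale grid.
block_size = 4
q = [[float(r * 6 + c + 1) for c in range(6)] for r in range(6)]
scales = [[0.5, 2.0], [4.0, 8.0]]

# The slice covers global rows 3..5 and columns 2..5.
row_offset, col_offset = 3, 2
q_slice = [row[col_offset:6] for row in q[row_offset:6]]

expected = [row[col_offset:6] for row in dequantize_full(q, scales, block_size)[row_offset:6]]
correct = dequantize_slice(q_slice, scales, block_size, row_offset, col_offset)
# Ignoring the offsets reproduces the failure mode: wrong scales for most elements.
naive = dequantize_slice(q_slice, scales, block_size, 0, 0)
```

Here `correct` matches `expected` element for element, while `naive` does not, because it reads scales as if the slice began at the tensor origin.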

Correctness validation using https://github.com/wwwjn/torchtitan/blob/dsv3-sd-test/tests/fsdp_dequantized_load.py
```
{
  "model.layers.0.mlp.gate_proj.weight": {
    "mse": 4.30626645453458e-11,
    "mae": 9.98388827611052e-07,
    "max_abs_diff": 0.0009703934192657471,
    "cosine_similarity": 1.010810375213623,
    "relative_error": 0.001330620958469808,
    "kl_divergence_1_to_2": "6.563401e-08",
    "kl_divergence_2_to_1": "-6.522914e-08",
    "js_divergence": 1.3711876079014476e-10,
    "shape": [
      18432,
      7168
    ],
    "t1_stats": {
      "min": -0.4453125,
      "max": 0.30859375,
      "mean": -1.2592146958922967e-05
    },
    "t2_stats": {
      "min": -0.44529813528060913,
      "max": 0.3085886240005493,
      "mean": -1.2624391274584923e-05
    }
  },
  "model.layers.0.mlp.up_proj.weight": {
    "mse": 2.5534721906361746e-11,
    "mae": 3.118609583907528e-06,
    "max_abs_diff": 0.00047551095485687256,
    "cosine_similarity": 1.038962483406067,
    "relative_error": 0.0013681650161743164,
    "kl_divergence_1_to_2": "-5.8253768e-08",
    "kl_divergence_2_to_1": "5.8747577e-08",
    "js_divergence": NaN,
    "shape": [
      18432,
      7168
    ],
    "t1_stats": {
      "min": -0.228515625,
      "max": 0.2333984375,
      "mean": 8.862222955485777e-08
    },
    "t2_stats": {
      "min": -0.2285017967224121,
      "max": 0.23338991403579712,
      "mean": 8.824501662729745e-08
    }
  },
  "model.layers.0.mlp.down_proj.weight": {
    "mse": 2.2803769289536646e-11,
    "mae": 2.8916260816913564e-06,
    "max_abs_diff": 0.0008973777294158936,
    "cosine_similarity": 1.0376262664794922,
    "relative_error": 0.001346255769021809,
    "kl_divergence_1_to_2": "1.2744896e-07",
    "kl_divergence_2_to_1": "-1.2736885e-07",
    "js_divergence": 5.992362162032805e-11,
    "shape": [
      7168,
      18432
    ],
    "t1_stats": {
      "min": -0.54296875,
      "max": 0.546875,
      "mean": -2.9487239316949854e-07
    },
    "t2_stats": {
      "min": -0.5429964661598206,
      "max": 0.5469087362289429,
      "mean": -2.9507478416235244e-07
    }
  }
}
```
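
For readers unfamiliar with the metric names in the report above, the following sketch shows one conventional way to compute four of them over flattened tensors. It is not taken from the linked validation script; the function name `compare` and the exact formulas are assumptions, and the script may define its metrics differently.

```python
import math
from typing import Dict, Sequence


def compare(t1: Sequence[float], t2: Sequence[float]) -> Dict[str, float]:
    """Hypothetical helper: element-wise error metrics between two flat tensors."""
    if len(t1) != len(t2) or not t1:
        raise ValueError("inputs must be non-empty and of equal length")
    diffs = [a - b for a, b in zip(t1, t2)]
    dot = sum(a * b for a, b in zip(t1, t2))
    norm1 = math.sqrt(sum(a * a for a in t1))
    norm2 = math.sqrt(sum(b * b for b in t2))
    return {
        "mse": sum(d * d for d in diffs) / len(diffs),
        "mae": sum(abs(d) for d in diffs) / len(diffs),
        "max_abs_diff": max(abs(d) for d in diffs),
        # Assumes neither input is all zeros.
        "cosine_similarity": dot / (norm1 * norm2),
    }


metrics = compare([1.0, 2.0, 3.0], [1.0, 2.0, 4.0])
```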

https://www.internalfb.com/intern/testinfra/testrun/3940649985202645

Differential Revision: D82975005

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163532
Approved by: https://github.com/wwwjn
2025-09-23 16:48:16 +00:00
_experimental [dcp_poc] Fix parameter order in distributed checkpoint API to use path-first for consistency (#160986) 2025-08-20 04:09:18 +00:00
examples
__init__.py [DCP][HuggingFace] Add Support for dequantization of SafeTensors checkpoints (#160682) 2025-09-04 01:09:53 +00:00
_async_executor.py [DCP][OSS] Rank local checkpointing in DCP without collectives (#147758) 2025-08-13 16:20:28 +00:00
_async_process_executor.py [DCP] Add timeout for checkpoint background process join (#162828) 2025-09-16 02:32:50 +00:00
_async_thread_executor.py [DCP][OSS] Rank local checkpointing in DCP without collectives (#147758) 2025-08-13 16:20:28 +00:00
_checkpointer.py
_consolidate_hf_safetensors.py Add pg argument to consolidate_safetensors_files_on_every_rank (#161421) 2025-08-29 13:31:11 +00:00
_dedup_save_plans.py
_dedup_tensors.py
_extension.py
_fsspec_filesystem.py
_hf_utils.py Use only safetensors APIs in HFStorageReader (#159681) 2025-08-07 17:23:03 +00:00
_nested_dict.py
_pg_transport.py [DCP] Add support for ShardedTensor to PgTransport (#158573) 2025-07-21 21:04:23 +00:00
_sharded_tensor_utils.py
_state_dict_stager.py fix forced loglevel in pytorch oss code (#158820) 2025-07-24 00:40:28 +00:00
_storage_utils.py
_traverse.py
_version.py
api.py
default_planner.py [DCP][OSS] Rank local checkpointing in DCP without collectives (#147758) 2025-08-13 16:20:28 +00:00
filesystem.py [DCP][OSS] Rank local checkpointing in DCP without collectives (#147758) 2025-08-13 16:20:28 +00:00
format_utils.py
hf_storage.py [DCP][HF] Add option to parallelize reads in HF Storage Reader (#160205) 2025-08-21 23:58:02 +00:00
logger.py
logging_handlers.py
metadata.py [oss] Add version to metadata (#155343) 2025-07-07 20:57:30 +00:00
optimizer.py
planner_helpers.py
planner.py
quantized_hf_storage.py [DCP] DTensor slice dequantization with proper block alignment (#163532) 2025-09-23 16:48:16 +00:00
resharding.py
staging.py [doc]: Small typos (#162982) 2025-09-16 17:42:19 +00:00
state_dict_loader.py [DCP][OSS] Rank local checkpointing in DCP without collectives (#147758) 2025-08-13 16:20:28 +00:00
state_dict_saver.py [DCP] Avoid multiple storage writer resets in async save (#159448) 2025-09-10 00:43:03 +00:00
state_dict.py
stateful.py
storage.py [DCP][OSS] Rank local checkpointing in DCP without collectives (#147758) 2025-08-13 16:20:28 +00:00
utils.py [BE] add noqa for flake8 rule B036: found except BaseException without re-raising (#159043) 2025-07-25 02:56:34 +00:00