Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-06 12:20:52 +01:00
Summary:

When loading quantized tensors with DTensor slicing, the dequantization process produced numerically incorrect results due to improper block-to-slice coordinate mapping. The previous implementation calculated block boundaries relative to the sliced tensor dimensions instead of the original full tensor dimensions, causing scale factors to be applied to the wrong tensor regions.

This fix addresses the issue by:

1. **Proper coordinate mapping**: Added `_get_slice_to_block_mapping()` to correctly map tensor slices to quantization blocks using global coordinates from the full tensor shape.
2. **Block-aligned dequantization**: Updated `_dequantize_tensor()` to use proper block-intersection logic, ensuring scale factors are applied to the correct portions of sliced tensors.

The fix ensures that when DTensor requests a slice of a quantized tensor, the dequantization correctly identifies which quantization blocks intersect the requested slice and applies the appropriate scale factors to the right tensor regions (a minimal sketch of this block-intersection logic follows the commit message below).

Test Plan:

Tested with DTensor configurations where quantized tensors are sliced across different dimensions. Verified that:

1. Dequantized tensor values are numerically correct.
2. Block boundaries are properly calculated relative to the full tensor shape.
3. Scale factors are applied to the correct tensor regions.
4. The tensor-shapes map is built efficiently using only metadata.

Correctness validation using https://github.com/wwwjn/torchtitan/blob/dsv3-sd-test/tests/fsdp_dequantized_load.py:

```
{
  "model.layers.0.mlp.gate_proj.weight": {
    "mse": 4.30626645453458e-11,
    "mae": 9.98388827611052e-07,
    "max_abs_diff": 0.0009703934192657471,
    "cosine_similarity": 1.010810375213623,
    "relative_error": 0.001330620958469808,
    "kl_divergence_1_to_2": "6.563401e-08",
    "kl_divergence_2_to_1": "-6.522914e-08",
    "js_divergence": 1.3711876079014476e-10,
    "shape": [18432, 7168],
    "t1_stats": {"min": -0.4453125, "max": 0.30859375, "mean": -1.2592146958922967e-05},
    "t2_stats": {"min": -0.44529813528060913, "max": 0.3085886240005493, "mean": -1.2624391274584923e-05}
  },
  "model.layers.0.mlp.up_proj.weight": {
    "mse": 2.5534721906361746e-11,
    "mae": 3.118609583907528e-06,
    "max_abs_diff": 0.00047551095485687256,
    "cosine_similarity": 1.038962483406067,
    "relative_error": 0.0013681650161743164,
    "kl_divergence_1_to_2": "-5.8253768e-08",
    "kl_divergence_2_to_1": "5.8747577e-08",
    "js_divergence": NaN,
    "shape": [18432, 7168],
    "t1_stats": {"min": -0.228515625, "max": 0.2333984375, "mean": 8.862222955485777e-08},
    "t2_stats": {"min": -0.2285017967224121, "max": 0.23338991403579712, "mean": 8.824501662729745e-08}
  },
  "model.layers.0.mlp.down_proj.weight": {
    "mse": 2.2803769289536646e-11,
    "mae": 2.8916260816913564e-06,
    "max_abs_diff": 0.0008973777294158936,
    "cosine_similarity": 1.0376262664794922,
    "relative_error": 0.001346255769021809,
    "kl_divergence_1_to_2": "1.2744896e-07",
    "kl_divergence_2_to_1": "-1.2736885e-07",
    "js_divergence": 5.992362162032805e-11,
    "shape": [7168, 18432],
    "t1_stats": {"min": -0.54296875, "max": 0.546875, "mean": -2.9487239316949854e-07},
    "t2_stats": {"min": -0.5429964661598206, "max": 0.5469087362289429, "mean": -2.9507478416235244e-07}
  }
}
```

https://www.internalfb.com/intern/testinfra/testrun/3940649985202645

Differential Revision: D82975005

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163532

Approved by: https://github.com/wwwjn
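For illustration, here is a minimal, hypothetical sketch of the block-intersection idea described in the commit message above. It is not the code from this PR: the names (`dequantize_slice`, `BLOCK_SIZE`), the 2-D shape, and the `[ceil(rows/B), ceil(cols/B)]` scale layout are assumptions made for the example. The point it demonstrates is that block boundaries are derived from the full tensor shape and the slice's global offsets, not from the sliced tensor's own dimensions.

```python
# Hypothetical sketch of block-wise dequantization for a tensor slice.
# Assumes square quantization blocks of BLOCK_SIZE and one scale per block
# of the *full* tensor; not the actual PyTorch implementation.
import torch

BLOCK_SIZE = 128  # assumed quantization block granularity


def dequantize_slice(
    q_slice: torch.Tensor,        # quantized data for the requested slice
    scales: torch.Tensor,         # per-block scales, shape [ceil(R/B), ceil(C/B)]
    full_shape: tuple[int, int],  # shape of the original, un-sliced tensor
    offsets: tuple[int, int],     # global start of the slice in the full tensor
) -> torch.Tensor:
    rows, cols = q_slice.shape
    out = q_slice.to(torch.float32)
    r0, c0 = offsets                 # slice origin in global coordinates
    r1, c1 = r0 + rows, c0 + cols    # slice end (exclusive) in global coordinates

    # Iterate over every quantization block of the full tensor that the slice touches.
    for br in range(r0 // BLOCK_SIZE, (r1 - 1) // BLOCK_SIZE + 1):
        for bc in range(c0 // BLOCK_SIZE, (c1 - 1) // BLOCK_SIZE + 1):
            # Block bounds in global coordinates, clipped to the full tensor.
            gr0, gr1 = br * BLOCK_SIZE, min((br + 1) * BLOCK_SIZE, full_shape[0])
            gc0, gc1 = bc * BLOCK_SIZE, min((bc + 1) * BLOCK_SIZE, full_shape[1])
            # Intersection of the block with the slice, translated to local coordinates.
            lr0, lr1 = max(gr0, r0) - r0, min(gr1, r1) - r0
            lc0, lc1 = max(gc0, c0) - c0, min(gc1, c1) - c0
            out[lr0:lr1, lc0:lc1] *= scales[br, bc]
    return out
```

Applied block by block in global coordinates, this reproduces the same scaling that dequantizing the full tensor and then slicing would produce for that region, which is what the numerical comparison in the Test Plan checks.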
| Name |
|---|
| _composable |
| _pycute |
| _shard |
| _tools |
| algorithms |
| bin |
| checkpoint |
| elastic |
| flight_recorder |
| fsdp |
| launcher |
| nn/jit |
| optim |
| pipelining |
| rpc |
| tensor |
| _test_template.py |
| argparse_util_test.py |
| test_backends.py |
| test_c10d_common.py |
| test_c10d_functional_native.py |
| test_c10d_gloo.py |
| test_c10d_logger.py |
| test_c10d_nccl.py |
| test_c10d_object_collectives.py |
| test_c10d_ops_nccl.py |
| test_c10d_pypg.py |
| test_c10d_spawn_gloo.py |
| test_c10d_spawn_nccl.py |
| test_c10d_spawn_ucc.py |
| test_c10d_spawn.py |
| test_c10d_ucc.py |
| test_collective_utils.py |
| test_composability.py |
| test_compute_comm_reordering.py |
| test_control_collectives.py |
| test_cupy_as_tensor.py |
| test_data_parallel.py |
| test_device_mesh.py |
| test_dist2.py |
| test_distributed_spawn.py |
| test_dynamo_distributed.py |
| test_fake_pg.py |
| test_functional_api.py |
| test_inductor_collectives.py |
| test_launcher.py |
| test_multi_threaded_pg.py |
| test_nccl.py |
| test_nvshmem_triton.py |
| test_nvshmem.py |
| test_p2p_ipc.py |
| test_pg_wrapper.py |
| test_run.py |
| test_serialization.py |
| test_store.py |
| test_symmetric_memory.py |