pytorch/torch/distributed/checkpoint/_hf_planner.py
Ankita George 8a40fca9a1 Support huggingface reading and writing for multi rank case (#148189)
Summary: This diff adds the ability for HF reader/writer to read/write in a distributed way. We do this by sending all the tensors meant for the same file to the same rank.

Test Plan:
Ensured existing tests pass.
I also ran a full end-to-end test on my devserver, reading from and writing to my HF repo.

Differential Revision: D70096439

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148189
Approved by: https://github.com/joecummings, https://github.com/saumishr
2025-03-26 14:47:31 +00:00

50 lines
1.5 KiB
Python

# mypy: allow-untyped-defs
from dataclasses import dataclass
from torch.distributed.checkpoint._dedup_save_plans import (
dedup_save_plans_with_fqn_to_index_mapping,
)
from torch.distributed.checkpoint.default_planner import (
DefaultLoadPlanner,
DefaultSavePlanner,
)
from torch.distributed.checkpoint.planner import ReadItem, SavePlan
# Names picked up by a star-import of this module: the save and load planners.
# _FqnToFileMapping is not listed here.
__all__ = ["_HuggingFaceSavePlanner", "_HuggingFaceLoadPlanner"]
@dataclass
class _FqnToFileMapping:
    """
    Container for the fqn -> file index mapping.

    ``_HuggingFaceSavePlanner._dedup_save_plans`` expects to find an instance of
    this class in the ``storage_data`` of the first save plan it is given.
    """

    # Maps a fully qualified name (fqn) to an integer file index.
    fqn_to_file_index_mapping: dict[str, int]
class _HuggingFaceSavePlanner(DefaultSavePlanner):
"""
A save planner that dedups the save plans based on the fqn to file index mapping.
"""
def _dedup_save_plans(self, all_plans: list[SavePlan]) -> list[SavePlan]:
assert len(all_plans) > 0, "all_plans should not be empty"
assert all_plans[0].storage_data is not None, "storage_data should not be None"
assert isinstance(all_plans[0].storage_data, _FqnToFileMapping), (
"storage_data should be of type _FqnToFileMapping"
)
fqn_to_index_mapping: dict[str, int] = all_plans[
0
].storage_data.fqn_to_file_index_mapping
return dedup_save_plans_with_fqn_to_index_mapping(
all_plans, fqn_to_index_mapping
)
class _HuggingFaceLoadPlanner(DefaultLoadPlanner):
    """
    A load planner that resolves a read item to the tensor looked up at its
    destination index.
    """

    def __init__(self, allow_tensor_resize: bool = False):
        """
        Args:
            allow_tensor_resize: Flag stored unchanged on the instance. Nothing
                in this class reads it.
                NOTE(review): presumably consumed by the HF storage reader --
                confirm against the caller.
        """
        super().__init__()
        self.allow_tensor_resize = allow_tensor_resize

    def resolve_tensor(self, read_item: ReadItem):
        """Return ``self.lookup_tensor`` applied to ``read_item.dest_index``."""
        destination = read_item.dest_index
        return self.lookup_tensor(destination)