Supporting non-tensor-data write_size in planner write items. (#149434)

Summary:
1. The current write item structure does not contain the amount of data that needs to be written.
2. The planner item already has a size primitive, `tensor_storage_size` (https://fburl.com/code/7a0gsmw7), but only for tensors.
3. Right now, the only way the writer layer can get hold of this property (for non-tensor data) is to:

- first do a lookup into the actual tensor/bytes object,
- then calculate the nbytes.

This change introduces a way to capture non-tensor data size within a write-plan item.
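As a rough sketch of the idea, the planner can record the payload size on the write item at planning time, so the writer never has to re-inspect the object. The dataclass shapes below are simplified stand-ins for the ones in `torch.distributed.checkpoint.planner`, and `make_bytes_write_item` is a hypothetical helper, not part of the actual API:

```python
import io
from dataclasses import dataclass
from typing import Optional


# Simplified stand-ins mirroring the dataclasses in the diff (illustrative only).
@dataclass(frozen=True)
class BytesIOWriteData:
    nbytes: int


@dataclass(frozen=True)
class WriteItem:
    index: str  # simplified; the real field is a MetadataIndex
    type: str   # simplified; the real field is a WriteItemType
    bytes_io_data: Optional[BytesIOWriteData] = None


# Hypothetical helper: capture the payload size once, at planning time,
# instead of forcing the writer layer to look up the object later.
def make_bytes_write_item(key: str, payload: io.BytesIO) -> WriteItem:
    size = payload.getbuffer().nbytes
    return WriteItem(
        index=key,
        type="BYTE_IO",
        bytes_io_data=BytesIOWriteData(nbytes=size),
    )


item = make_bytes_write_item("optimizer_state", io.BytesIO(b"\x00" * 128))
print(item.bytes_io_data.nbytes)  # 128
```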

Reviewed By: daulet-askarov

Differential Revision: D70497442

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149434
Approved by: https://github.com/MeetVadakkanchery
This commit is contained in:
Pradeep Fernando 2025-03-20 01:22:05 +00:00 committed by PyTorch MergeBot
parent 02e21c7854
commit 1442230a26


@@ -20,6 +20,7 @@ from torch.distributed.checkpoint.metadata import (
 __all__ = [
     "WriteItemType",
     "LoadItemType",
+    "BytesIOWriteData",
     "TensorWriteData",
     "WriteItem",
     "ReadItem",
@@ -41,6 +42,11 @@ class LoadItemType(Enum):
     BYTE_IO = auto()


+@dataclass(frozen=True)
+class BytesIOWriteData:
+    nbytes: int
+
+
 @dataclass(frozen=True)
 class TensorWriteData:
     chunk: ChunkStorageMetadata
@@ -55,6 +61,9 @@ class WriteItem:
     index: MetadataIndex
     type: WriteItemType

+    # Size of bytesIO data to be written.
+    bytes_io_data: Optional[BytesIOWriteData] = None
+
     # Value present if it's a tensor write
     tensor_data: Optional[TensorWriteData] = None
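With both size records on the plan item, a writer can compute planned sizes directly from the plan. This is a minimal sketch under simplified assumptions: the stand-in dataclasses below flatten the real ones (in the actual code, tensor size comes from `tensor_storage_size`, derived from the chunk/tensor properties), and `planned_nbytes` is a hypothetical helper:

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class BytesIOWriteData:
    nbytes: int


@dataclass(frozen=True)
class TensorWriteData:
    nbytes: int  # simplified; real size is derived from chunk/properties


@dataclass(frozen=True)
class WriteItem:
    bytes_io_data: Optional[BytesIOWriteData] = None
    tensor_data: Optional[TensorWriteData] = None


# Hypothetical helper: report the size recorded on the plan item,
# whichever kind of payload it describes.
def planned_nbytes(item: WriteItem) -> Optional[int]:
    if item.tensor_data is not None:
        return item.tensor_data.nbytes
    if item.bytes_io_data is not None:
        return item.bytes_io_data.nbytes
    return None  # size unknown; writer must fall back to inspecting the object


plan = [
    WriteItem(bytes_io_data=BytesIOWriteData(nbytes=64)),
    WriteItem(tensor_data=TensorWriteData(nbytes=4096)),
]
print(sum(planned_nbytes(i) or 0 for i in plan))  # 4160
```

The key point is that no plan item requires a lookup into the underlying tensor or bytes object to answer "how much will this write?".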