[checkpoint] Adopt Planner interface across the board. (#83781)

Change StorageReader and StorageWriter to follow the new SavePlanner / LoadPlanner design.

Add an optional planner parameter to load_state_dict and save_state_dict and implement the new protocol.

This includes a small rework of the FileSystem layer to support a single file per rank and to make fsync optional, matching torch.save behavior.
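
As a rough usage sketch of the new API (illustrative only: the /tmp/ckpt path is arbitrary, the filesystem module path is assumed from this PR's file layout, and passing the default planners explicitly is redundant and done only to show the new parameter):

import torch
from torch.distributed._shard.checkpoint import save_state_dict, load_state_dict
from torch.distributed._shard.checkpoint.filesystem import FileSystemWriter, FileSystemReader
from torch.distributed._shard.checkpoint.default_planner import (
    DefaultSavePlanner,
    DefaultLoadPlanner,
)

state_dict = {"weight": torch.rand(4, 4), "step": 10}

# Save: the planner decides what gets written, the writer decides how/where.
save_state_dict(
    state_dict=state_dict,
    storage_writer=FileSystemWriter("/tmp/ckpt", single_file_per_rank=True),
    planner=DefaultSavePlanner(),
    no_dist=True,  # single-process use, as in the new tests
)

# Load in place: tensors in state_dict are overwritten with checkpoint data.
load_state_dict(
    state_dict=state_dict,
    storage_reader=FileSystemReader("/tmp/ckpt"),
    planner=DefaultLoadPlanner(),
    no_dist=True,
)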

Pull Request resolved: https://github.com/pytorch/pytorch/pull/83781
Approved by: https://github.com/wanchaol, https://github.com/fduwjj
Rodrigo Kumpera 2022-08-29 14:38:32 +00:00 committed by PyTorch MergeBot
parent fbf5a3f9f4
commit f66be71d77
11 changed files with 600 additions and 836 deletions

View File

@ -1,8 +1,9 @@
# Owner(s): ["oncall: distributed"]
import random
import sys
from typing import Optional, List, Union, cast
from typing import Optional, List, cast
from torch.distributed._shard.checkpoint.storage import WriteResult
from torch.distributed._shard.checkpoint import (
StorageReader,
StorageWriter,
@ -16,26 +17,24 @@ import torch.distributed as dist
import torch.nn
import torch.futures
from torch.futures import Future
from torch.testing._internal.common_utils import TestCase
from torch.distributed._shard.checkpoint.resharding import (
_prepare_sharded_tensor_write,
_create_storage_key
)
from torch.distributed._shard import sharded_tensor
from torch.distributed._shard.checkpoint.state_dict_saver import (
_prepare,
from torch.distributed._shard.checkpoint.default_planner import (
_create_default_local_metadata,
)
from torch.distributed._shard.checkpoint.metadata import (
BytesStorageMetadata,
Metadata,
BytesReadRequest,
BytesWriteRequest,
MetadataIndex,
TensorReadRequest,
TensorWriteRequest,
TensorStorageMetadata,
)
from torch.distributed._shard.checkpoint.planner import (
SavePlan,
SavePlanner,
LoadPlan,
LoadPlanner,
)
from torch.distributed._shard.sharded_tensor import (
@ -89,17 +88,6 @@ class TestDistributedCheckpointing(ShardedTensorTestBase):
def world_size(self) -> int:
return 2
def gen_metadata(self) -> Metadata:
module = TestModule()
# compute the default saved metadata (must pass include_non_replicated_tensors or we'll get incomplete MD)
metadata, _, _ = _prepare(module.state_dict(), True)
# _prepare only produc
metadata = [metadata]
dist.broadcast_object_list(metadata)
return metadata[0]
@with_comms(init_rpc=False)
@skip_if_lt_x_gpu(2)
@requires_nccl()
@ -114,15 +102,15 @@ class TestDistributedCheckpointing(ShardedTensorTestBase):
st = sharded_tensor.zeros(spec, 4, 4, dtype=torch.float64)
mapping = {}
(_, md, storage_md) = _prepare_sharded_tensor_write("fqn", st, "tensor", mapping)
md = _create_default_local_metadata({"st": st})
self.assertEqual(1, len(storage_md))
self.assertEqual(1, len(mapping))
st_md = md.state_dict_metadata["st"]
self.assertEqual(1, len(st_md.chunks))
@with_comms(init_rpc=False)
@skip_if_lt_x_gpu(2)
@requires_nccl()
def test_storage_key_mapping(self) -> None:
def test_default_metadata(self) -> None:
device = f"cuda:{dist.get_rank()}"
spec = ChunkShardingSpec(
dim=0,
@ -138,48 +126,23 @@ class TestDistributedCheckpointing(ShardedTensorTestBase):
'bytes': [1, 2, 3, 4],
}
metadata, bytes_reqs, tensor_reqs = _prepare(state_dict, write_replicated_data=self.rank == 0)
if self.rank == 0:
self.assertEqual(1, len(bytes_reqs))
self.assertEqual(2, len(tensor_reqs))
self.assertTrue('bytes' in metadata.state_dict_metadata)
self.assertTrue(MetadataIndex('bytes') in metadata.storage_data)
# tensor ordering is unspecified
if len(tensor_reqs[0].tensor.size()) == 1:
replicated = tensor_reqs[0]
shard = tensor_reqs[1]
else:
replicated = tensor_reqs[1]
shard = tensor_reqs[0]
self.assertTrue('replicated' in metadata.state_dict_metadata)
storage_key = MetadataIndex('replicated', torch.Size([0]))
self.assertTrue(storage_key in metadata.storage_data)
self.assertTrue(metadata.storage_data[storage_key], replicated.storage_key)
else:
self.assertEqual(0, len(bytes_reqs))
self.assertEqual(1, len(tensor_reqs))
shard = tensor_reqs[0]
local_shard = state_dict["sharded"].local_shards()[0]
self.assertTrue('sharded' in metadata.state_dict_metadata)
storage_key = MetadataIndex('sharded', torch.Size(local_shard.metadata.shard_offsets))
self.assertTrue(storage_key in metadata.storage_data)
self.assertTrue(metadata.storage_data[storage_key], shard.storage_key)
class TestStorageKeys(TestCase):
def test_create_key_handles_collision(self):
keys = {}
key0 = _create_storage_key(keys, "foo")
key1 = _create_storage_key(keys, "foo")
self.assertNotEqual(key0, key1)
metadata = _create_default_local_metadata(state_dict)
self.assertTrue('bytes' in metadata.state_dict_metadata)
self.assertIsInstance(metadata.state_dict_metadata['bytes'], BytesStorageMetadata)
self.assertTrue('replicated' in metadata.state_dict_metadata)
self.assertIsInstance(metadata.state_dict_metadata['replicated'], TensorStorageMetadata)
md = metadata.state_dict_metadata['replicated']
self.assertEqual(md.size, state_dict['replicated'].size())
self.assertEqual(md.properties.dtype, torch.float32)
self.assertEqual(1, len(md.chunks))
self.assertTrue('sharded' in metadata.state_dict_metadata)
self.assertIsInstance(metadata.state_dict_metadata['sharded'], TensorStorageMetadata)
md = metadata.state_dict_metadata['sharded']
self.assertEqual(md.properties.dtype, torch.float32)
self.assertEqual(md.size, state_dict['sharded'].size())
self.assertEqual(2, len(md.chunks))
class TestStorageBase:
def __init__(
@ -197,16 +160,15 @@ class TestStorageBase:
if ranks is not None and self.rank in ranks:
raise ValueError(f"rank fail {self.rank} for {name}")
def _fail_rank_async(self, name):
def _fail_rank_async(self, name, result=None):
ranks = self._get_ranks(name)
fut = Future()
if ranks is not None and self.rank in ranks:
fut.set_exception(ValueError(f"async rank fail {self.rank} for {name}"))
else:
fut.set_result(None)
fut.set_result(result)
return fut
class FaultyStorageWriter(TestStorageBase, StorageWriter):
def __init__(
self,
@ -214,22 +176,28 @@ class FaultyStorageWriter(TestStorageBase, StorageWriter):
):
super(FaultyStorageWriter, self).__init__(fail_conf)
def prepare(self) -> None:
self._fail_rank("fail_prepare")
def init(self, is_coordinator: bool) -> None:
self._fail_rank("fail_init")
def write_bytes(self, requests: List[BytesWriteRequest]) -> Future[None]:
self._fail_rank("fail_write_bytes_on_ranks")
return self._fail_rank_async("fail_write_bytes_on_ranks_async")
def prepare_local_plan(self, plan: SavePlan) -> SavePlan:
self._fail_rank("fail_prepare_local_plan")
return plan
def write_tensors(self, requests: List[TensorWriteRequest]) -> Future[None]:
self._fail_rank("fail_write_tensors_on_ranks")
return self._fail_rank_async("fail_write_tensors_on_ranks_async")
def prepare_global_plan(self, plans: List[SavePlan]) -> List[SavePlan]:
self._fail_rank("fail_prepare_global_plan")
return plans
def finish(self, metadata: Metadata) -> None:
def write_data(
self,
plan: SavePlan,
planner: SavePlanner
) -> Future[List[WriteResult]]:
self._fail_rank("fail_write_data")
return self._fail_rank_async("fail_write_data_async", [])
def finish(self, metadata: Metadata, results: List[List[WriteResult]]) -> None:
self._fail_rank("fail_finish")
def prepare_storage(self, storage_writes: List[Union[TensorWriteRequest, BytesWriteRequest]]) -> None:
self._fail_rank("fail_prepare_storage")
class FaultyStorageReader(TestStorageBase, StorageReader):
def __init__(
@ -240,22 +208,24 @@ class FaultyStorageReader(TestStorageBase, StorageReader):
super(FaultyStorageReader, self).__init__(fail_conf)
self.metadata = metadata
def read_bytes(self, requests: List[BytesReadRequest]) -> Future[None]:
self._fail_rank("fail_read_bytes")
bad_ranks = self._get_ranks("fail_deser_bytes")
for r in requests:
if bad_ranks is not None and self.rank in bad_ranks:
# this is not "guaranteed" to fail, but hard to beat
rand = random.Random(1237)
r.bytes.write(rand.randbytes(32))
else:
torch.save([1, 2, 3], r.bytes)
def init(self, metadata: Metadata, is_coordinator: bool) -> None:
self._fail_rank("fail_init")
return self._fail_rank_async("fail_read_bytes_async")
def prepare_local_plan(self, plan: LoadPlan) -> LoadPlan:
self._fail_rank("fail_prepare_local_plan")
return plan
def read_tensors(self, requests: List[TensorReadRequest]) -> Future[None]:
self._fail_rank("fail_read_tensors")
return self._fail_rank_async("fail_read_tensors_async")
def prepare_global_plan(self, plans: List[LoadPlan]) -> List[LoadPlan]:
self._fail_rank("fail_prepare_global_plan")
return plans
def read_data(
self,
plan: LoadPlan,
planner: LoadPlanner
) -> Future[None]:
self._fail_rank("fail_read_data")
return self._fail_rank_async("fail_read_data_async")
def read_metadata(self) -> Metadata:
self._fail_rank("fail_read_metadata")
@ -282,6 +252,18 @@ class TestDistributedFailure(ShardedTensorTestBase):
save_state_dict(state_dict, FaultyStorageWriter({}))
@with_comms(init_rpc=False)
@skip_if_lt_x_gpu(2)
@requires_nccl()
def test_dummy_reader_works(self) -> None:
state_dict = {
'sharded': sharded_tensor.rand(self.get_spec(), 20, 20),
'replicated': torch.rand(10, 10),
'bytes': [1, 2, 3, 4]
}
metadata = _create_default_local_metadata(state_dict)
load_state_dict(state_dict, FaultyStorageReader(metadata, {}))
def _test_dist_failure(self, callback, kwargs):
bad_ranks = list(kwargs.values())[0] if len(kwargs) > 0 else []
@ -318,10 +300,9 @@ class TestDistributedFailure(ShardedTensorTestBase):
def _test_load(self, state_dict, coordinator=0, **kwargs):
no_dist = not dist.is_initialized()
write_replicated = dist.is_initialized() and dist.get_rank() == coordinator
def _load():
metadata, _, _ = _prepare(state_dict, write_replicated)
metadata = _create_default_local_metadata(state_dict)
load_state_dict(
state_dict,
storage_reader=FaultyStorageReader(metadata, kwargs),
@ -341,21 +322,17 @@ class TestDistributedFailure(ShardedTensorTestBase):
'bytes': [1, 2, 3, 4]
}
self._test_save(state_dict, fail_prepare=[0])
self._test_save(state_dict, fail_init=[0])
self._test_save(state_dict, fail_finish=[0])
self._test_save(state_dict, fail_prepare_global_plan=[0])
self._test_save(state_dict, fail_prepare_storage=[0])
self._test_save(state_dict, fail_write_tensors_on_ranks=[1])
self._test_save(state_dict, fail_write_tensors_on_ranks_async=[2])
self._test_save(state_dict, fail_write_bytes_on_ranks=[3])
self._test_save(state_dict, fail_write_bytes_on_ranks_async=[1])
self._test_save(state_dict, fail_prepare_local_plan=[0])
self._test_save(state_dict, fail_write_data=[2])
self._test_save(state_dict, fail_write_data_async=[3])
self._test_save(state_dict, fail_write_tensors_on_ranks_async=[1, 3])
self._test_save(state_dict, coordinator=1, fail_prepare=[1])
self._test_save(state_dict, coordinator=1, fail_init=[1])
self._test_save(state_dict, coordinator=1, fail_finish=[1])
def test_save_error_handling_no_dist(self) -> None:
state_dict = {
'replicated': torch.rand(10, 10),
@ -364,14 +341,13 @@ class TestDistributedFailure(ShardedTensorTestBase):
self.assertFalse(dist.is_initialized())
self._test_save(state_dict, fail_prepare=[0])
self._test_save(state_dict, fail_init=[0])
self._test_save(state_dict, fail_finish=[0])
self._test_save(state_dict, fail_prepare_global_plan=[0])
self._test_save(state_dict, fail_prepare_storage=[0])
self._test_save(state_dict, fail_write_tensors_on_ranks=[0])
self._test_save(state_dict, fail_write_tensors_on_ranks_async=[0])
self._test_save(state_dict, fail_write_bytes_on_ranks=[0])
self._test_save(state_dict, fail_write_bytes_on_ranks_async=[0])
self._test_save(state_dict, fail_prepare_local_plan=[0])
self._test_save(state_dict, fail_write_data=[0])
self._test_save(state_dict, fail_write_data_async=[0])
@with_comms(init_rpc=False)
@skip_if_lt_x_gpu(4)
@ -384,17 +360,18 @@ class TestDistributedFailure(ShardedTensorTestBase):
}
self._test_load(state_dict)
self._test_load(state_dict, fail_init=[0])
self._test_load(state_dict, fail_prepare_global_plan=[0])
self._test_load(state_dict, fail_read_metadata=[0])
self._test_load(state_dict, fail_read_bytes=[1])
self._test_load(state_dict, fail_read_bytes_async=[2])
self._test_load(state_dict, fail_read_tensors=[3])
self._test_load(state_dict, fail_read_tensors_async=[1])
# We don't want to depend on the actual exception raised by pickle
self._test_load(state_dict, fail_deser_bytes=[2], ignore_exception_type=True)
self._test_load(state_dict, fail_prepare_local_plan=[1])
self._test_load(state_dict, fail_read_data=[3])
self._test_load(state_dict, fail_read_data_async=[1])
self._test_load(state_dict, coordinator=3, fail_init=[0])
self._test_load(state_dict, coordinator=1, fail_read_metadata=[3])
self._test_load(state_dict, coordinator=2, fail_read_bytes=[0])
self._test_load(state_dict, coordinator=3, fail_read_tensors_async=[2])
self._test_load(state_dict, coordinator=2, fail_read_data=[0])
self._test_load(state_dict, coordinator=3, fail_read_data_async=[2])
self._test_load(state_dict, coordinator=1, fail_prepare_global_plan=[1])
def test_load_error_handling_no_dist(self) -> None:
@ -403,11 +380,12 @@ class TestDistributedFailure(ShardedTensorTestBase):
'bytes': [1, 2, 3, 4]
}
self._test_load(state_dict)
self._test_load(state_dict, fail_init=[0])
self._test_load(state_dict, fail_read_metadata=[0])
self._test_load(state_dict, fail_read_bytes=[0])
self._test_load(state_dict, fail_read_bytes_async=[0])
self._test_load(state_dict, fail_read_tensors=[0])
self._test_load(state_dict, fail_read_tensors_async=[0])
self._test_load(state_dict, fail_deser_bytes=[0], ignore_exception_type=True)
self._test_load(state_dict, fail_prepare_local_plan=[0])
self._test_load(state_dict, fail_prepare_global_plan=[0])
self._test_load(state_dict, fail_read_data=[0])
self._test_load(state_dict, fail_read_data_async=[0])
if __name__ == "__main__":
run_tests()

View File

@ -1,6 +1,5 @@
# Owner(s): ["oncall: distributed"]
import sys
import os
import shutil
import tempfile
@ -119,6 +118,23 @@ class TestDistributedStateDictSaveLoad(TestCase):
assert_state_dict_equal(self, state_dict_to_load_to, state_dict_to_save)
with tempfile.TemporaryDirectory() as path:
state_dict_to_save = MyTestModule().state_dict()
fs_writer = FileSystemWriter(path=path, single_file_per_rank=True)
save_state_dict(state_dict=state_dict_to_save, storage_writer=fs_writer, no_dist=True)
state_dict_to_load_to = MyTestModule().state_dict()
with self.assertRaises(AssertionError):
assert_state_dict_equal(self, state_dict_to_load_to, state_dict_to_save)
# Load from file without any resharding
fs_reader = FileSystemReader(path=path)
load_state_dict(state_dict=state_dict_to_load_to, storage_reader=fs_reader, no_dist=True)
assert_state_dict_equal(self, state_dict_to_load_to, state_dict_to_save)
class TestDistributedStateDictSaveLoadWithSharedTensor(ShardedTensorTestBase):
@property

View File

@ -116,6 +116,7 @@ class TestSavePlan(TestCase):
all_plans = [create_data(0), create_data(1), create_data(2), create_data(3)]
final_plans, metadata = create_default_global_save_plan(all_plans=all_plans)
# The default global plan updates all indexes to include hints
for new_plan, old_plan in zip(final_plans, all_plans):
for new_item, old_item in zip(new_plan.items, old_plan.items):

View File

@ -1,12 +1,8 @@
from .metadata import (
BytesReadRequest,
BytesWriteRequest,
TensorStorageMetadata,
BytesStorageMetadata,
ChunkStorageMetadata,
Metadata,
TensorReadRequest,
TensorWriteRequest,
)
from .state_dict_loader import load_state_dict
from .state_dict_saver import save_state_dict

View File

@ -1,7 +1,9 @@
from dataclasses import dataclass
import os
import operator
import dataclasses
import io
import pickle
from typing import List, Optional, Union, cast
from typing import List, Union, Dict, cast
import torch
from torch import Tensor
@ -9,16 +11,100 @@ from torch.futures import Future
from pathlib import Path
from .metadata import (
BytesReadRequest,
BytesWriteRequest,
Metadata,
TensorReadRequest,
TensorWriteRequest,
MetadataIndex,
)
from .storage import StorageReader, StorageWriter
from .storage import (
StorageReader,
StorageWriter,
WriteResult,
)
from .planner import (
LoadItemType,
LoadPlanner,
LoadPlan,
SavePlan,
SavePlanner,
ReadItem,
WriteItem,
WriteItemType,
)
from torch.distributed._shard._utils import narrow_tensor_by_index
@dataclass
class _StorageInfo:
"""
This is the per-entry storage info.
"""
relative_path: str
offset: int
length: int
@dataclass
class _StoragePrefix:
prefix: str
DEFAULT_SUFIX = ".distcp"
def _trim(tensor: torch.Tensor) -> torch.Tensor:
tensor = tensor.detach().cpu()
if tensor.storage().size() != tensor.numel():
tensor = tensor.clone()
return tensor
def _result_from_write_item(item: WriteItem, size_in_bytes, storage_data) -> WriteResult:
return WriteResult(
index=item.index,
size_in_bytes=size_in_bytes,
storage_data=storage_data)
def _write_item(stream, data, write_item, storage_key):
offset = stream.tell()
if write_item.type == WriteItemType.BYTE_IO:
assert isinstance(data, io.BytesIO)
stream.write(data.getbuffer())
else:
assert isinstance(data, torch.Tensor)
assert data.device == torch.device("cpu")
torch.save(data, stream)
length = stream.tell() - offset
return _result_from_write_item(
write_item,
length,
_StorageInfo(storage_key, offset, length)
)
def _write_files_from_queue(
file_queue: List,
planner: SavePlanner,
use_fsync: bool,
):
write_results = []
for file_path, file_name, write_items in file_queue:
tensor_w = [wi for wi in write_items if wi.type != WriteItemType.BYTE_IO]
bytes_w = [wi for wi in write_items if wi.type == WriteItemType.BYTE_IO]
with open(file_path, "wb") as stream:
for write_item in bytes_w:
data = planner.resolve_data(write_item)
write_results.append(_write_item(stream, data, write_item, file_name))
for write_item in tensor_w:
tensor = _trim(cast(torch.Tensor, planner.resolve_data(write_item)))
assert not tensor.is_cuda
write_results.append(_write_item(stream, tensor, write_item, file_name))
if use_fsync:
os.fsync(stream.fileno())
return write_results
class FileSystemWriter(StorageWriter):
"""
Basic implementation of StorageWriter using file IO.
@ -32,108 +118,152 @@ class FileSystemWriter(StorageWriter):
a `.metadata` file with the serialized metadata.
"""
def __init__(self, path: Union[str, os.PathLike]) -> None:
def __init__(
self,
path: Union[str, os.PathLike],
single_file_per_rank: bool = False,
sync_files: bool = True,
) -> None:
"""
Initialize the writer pointing to `path`.
Args:
path: directory where the checkpoint will be written to.
single_file_per_rank: Produce one file per rank instead of one file per tensor/blob. Defaults to False.
sync_files: Force files to be synced to permanent storage. Defaults to True.
N.B. If sync_files is disabled, there's no guarantee that the checkpoint will be consistent in the case of a failure.
"""
super().__init__()
self.path = Path(path)
self.single_file_per_rank = single_file_per_rank
self.sync_files = sync_files
def write_bytes(self, requests: List[BytesWriteRequest]) -> Future[None]:
for req in requests:
with (self.path / req.storage_key).open("wb") as w:
w.write(req.bytes.getbuffer())
os.fsync(w.fileno())
def init(self, is_coordinator: bool) -> None:
pass
fut: Future[None] = Future()
fut.set_result(None)
return fut
def prepare_local_plan(self, plan: SavePlan) -> SavePlan:
# There's no storage input in the local plan
return plan
def write_tensors(self, requests: List[TensorWriteRequest]) -> Future[None]:
for req in requests:
# The following couple lines are simple implementation to get
# things going.
#
# At load time, to enable resharding, we use (sub)view of the tensor.
# Since the storage of the tensor might not be contiguous. we need to
# preserve the original view, to calculate the correct sub view at load.
#
# `torch.save` saves both the view and storage, it is a good option
# for unblocking. There are two drawbacks:
# 1. `torch.save` is pickle based, and pickle is not known for its
# compatibility, we should consider replacing it with a more
# stable option.
# 2. pickle is not streamable.
with (self.path / req.storage_key).open("wb") as w:
torch.save(req.tensor, w)
os.fsync(w.fileno())
fut: Future[None] = Future()
fut.set_result(None)
return fut
def prepare(self) -> None:
def prepare_global_plan(self, global_plan: List[SavePlan]) -> List[SavePlan]:
self.path.mkdir(parents=True, exist_ok=True)
def finish(self, metadata: Metadata) -> None:
new_plans = [
dataclasses.replace(plan, storage_data=_StoragePrefix(f"__{i}_")) for i, plan in enumerate(global_plan)
]
return new_plans
def write_data(
self,
plan: SavePlan,
planner: SavePlanner,
) -> Future[List[WriteResult]]:
storage_plan: _StoragePrefix = plan.storage_data
file_count = 0
def gen_file():
nonlocal file_count
file_name = f"{storage_plan.prefix}{file_count}{DEFAULT_SUFIX}"
file_count += 1
return file_name
file_queue = []
if self.single_file_per_rank:
file_name = gen_file()
file_queue.append((self.path / file_name, file_name, plan.items))
else:
for item in plan.items:
file_name = gen_file()
file_queue.append((self.path / file_name, file_name, [item]))
results = _write_files_from_queue(
file_queue=file_queue,
planner=planner,
use_fsync=self.sync_files,
)
fut: Future[List[WriteResult]] = Future()
fut.set_result(results)
return fut
def finish(self, metadata: Metadata, results: List[List[WriteResult]]) -> None:
storage_md = dict()
for wr_list in results:
storage_md.update({
wr.index: wr.storage_data for wr in wr_list
})
metadata.storage_data = storage_md
with (self.path / ".metadata.tmp").open("wb") as metadata_file:
pickle.dump(metadata, metadata_file)
os.fsync(metadata_file.fileno())
(self.path / ".metadata.tmp").rename(self.path / ".metadata")
class SlicedBufferedReader(io.BufferedReader):
# TODO override read to handle (-1) correctly
def __init__(self, base_stream: io.RawIOBase, offset: int, len: int):
super().__init__(base_stream)
self.offset = offset
self.len = len
self.seek(0)
def seek(self, __offset: int, __whence: int = os.SEEK_SET) -> int:
if __whence == os.SEEK_SET:
__offset = self.offset + __offset
elif __whence == os.SEEK_END:
__whence = os.SEEK_SET
__offset = (self.offset + self.len) - __offset
return super().seek(__offset, __whence)
def tell(self) -> int:
return super().tell() - self.offset
class FileSystemReader(StorageReader):
def __init__(self, path: Union[str, os.PathLike]) -> None:
super().__init__()
self.path = Path(path)
self.storage_data: Dict[MetadataIndex, _StorageInfo] = dict()
def read_tensors(self, requests: List[TensorReadRequest]) -> Future[None]:
"""
Very basic implementation that read from file system.
"""
# Sort the the requests by storage key and try to reuse the loaded tensors
requests.sort(key=operator.attrgetter("storage_key"))
def _slice_file(self, file, sinfo: _StorageInfo):
return SlicedBufferedReader(
io.FileIO(file.fileno(), closefd=False),
sinfo.offset, sinfo.length
)
cached_storage_key = None
view_cached: Optional[Tensor] = None
def read_data(
self,
plan: LoadPlan,
planner: LoadPlanner
) -> Future[None]:
# group requests by file
per_file: Dict[str, List[ReadItem]] = dict()
for read_item in plan.items:
item_md = self.storage_data[read_item.storage_index]
path = item_md.relative_path
per_file.setdefault(path, []).append(read_item)
for req in requests:
if cached_storage_key != req.storage_key or \
(view_cached is not None and view_cached.device != req.tensor.device):
for relative_path, reqs in per_file.items():
with (self.path / relative_path).open("rb") as file:
# TODO sort by offset and cache the reading
for req in reqs:
item_md = self.storage_data[req.storage_index]
file_slice = self._slice_file(file, item_md)
if req.type == LoadItemType.BYTE_IO:
bytes = io.BytesIO(file_slice.read(item_md.length))
bytes.seek(0)
planner.load_bytes(req, bytes)
else:
tensor = cast(Tensor, torch.load(file_slice, map_location="cpu"))
tensor = narrow_tensor_by_index(tensor, req.storage_offsets, req.lengths)
target_tensor = planner.resolve_tensor(req).detach()
with (self.path / req.storage_key).open("rb") as storage:
view_cached = cast(Tensor, torch.load(storage, map_location=req.tensor.device))
cached_storage_key = req.storage_key
view_to_copy: Tensor = cast(Tensor, view_cached)
# FileSystemWrite writes the tensor as is during save.
# During load time, we will load the Tensor (with it orignal view)
# narrow it along all dimemsions, and copy_ it to the
# target tensor, which will be the same size.
view_to_copy = narrow_tensor_by_index(view_to_copy, req.offsets, req.lengths)
assert (
view_to_copy.size() == req.tensor.size()
), f"The {req.storage_key} src/dst size does not match."
assert (
view_to_copy.device == req.tensor.device
), f"cannot load across devices {view_to_copy.device} vs {req.tensor.device}"
req.tensor.copy_(view_to_copy)
fut: Future = Future()
fut.set_result(None)
return fut
def read_bytes(self, requests: List[BytesReadRequest]) -> Future[None]:
for req in requests:
with (self.path / req.storage_key).open("rb") as storage:
req.bytes.write(storage.read())
assert (
target_tensor.size() == tensor.size()
), f"req {req.storage_index} mismatch sizes {target_tensor.size()} vs {tensor.size()}"
target_tensor.copy_(tensor)
planner.commit_tensor(req, target_tensor)
fut: Future = Future()
fut.set_result(None)
@ -143,3 +273,13 @@ class FileSystemReader(StorageReader):
def read_metadata(self) -> Metadata:
with (self.path / ".metadata").open("rb") as metadata_file:
return pickle.load(metadata_file)
def init(self, metadata: Metadata, is_coordinator: bool) -> None:
self.storage_data = metadata.storage_data
assert self.storage_data is not None
def prepare_local_plan(self, plan: LoadPlan) -> LoadPlan:
return plan
def prepare_global_plan(self, global_plan: List[LoadPlan]) -> List[LoadPlan]:
return global_plan
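
For illustration, a small sketch of how the resulting checkpoint can be inspected (assumes a checkpoint already written by FileSystemWriter at an arbitrary /tmp/ckpt; the pickle-based .metadata layout and the _StorageInfo fields come from this diff and are an implementation detail, not a stable format):

import pickle
from pathlib import Path

# Hypothetical location of a checkpoint written by FileSystemWriter.
path = Path("/tmp/ckpt")
with (path / ".metadata").open("rb") as f:
    metadata = pickle.load(f)

# FileSystemWriter.finish() stores a MetadataIndex -> _StorageInfo mapping in
# metadata.storage_data; FileSystemReader uses it to locate each item as a
# (relative_path, offset, length) slice of a per-rank file.
for index, info in metadata.storage_data.items():
    print(index.fqn, index.offset, "->", info.relative_path, info.offset, info.length)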

View File

@ -1,12 +1,11 @@
import io
from dataclasses import dataclass, field
from typing import Dict, List, Tuple, Union, Optional, Sequence, Any
from typing import Dict, List, Union, Optional, Sequence, Any
from torch.distributed._shard.sharded_tensor.metadata import TensorProperties
import torch
from torch.distributed._shard.sharded_tensor import (
ShardedTensor,
)
from torch.distributed._shard.sharded_tensor.metadata import TensorProperties
@dataclass
class ChunkStorageMetadata:
@ -37,34 +36,6 @@ class Metadata:
planner_data: Any = None
storage_data: Any = None
@dataclass
class BytesWriteRequest:
bytes: io.BytesIO
storage_key: str
@dataclass
class BytesReadRequest:
bytes: io.BytesIO
storage_key: str
fqn: str
@dataclass
class TensorWriteRequest:
tensor: torch.Tensor
storage_key: str
@dataclass
class TensorReadRequest:
tensor: torch.Tensor
storage_key: str
# offset and length w.r.t. to the storage identified by ``storage_key``
offsets: Tuple[int, ...]
lengths: Tuple[int, ...]
@dataclass(frozen=True)
class MetadataIndex:
"""

View File

@ -1,61 +1,8 @@
import hashlib
import io
from typing import List, Tuple, Dict
from typing import List, Tuple
import torch
from torch import Tensor
from torch.distributed._shard.sharded_tensor import (
ShardedTensor,
)
from torch.distributed._shard.sharding_spec import (
ShardMetadata,
)
from torch.distributed._shard.sharding_spec._internals import (
_check_shard_metadata_pair_overlap,
)
from torch.distributed._shard.sharded_tensor.shard import Shard
from torch.distributed._shard.sharded_tensor.metadata import TensorProperties
from .metadata import (
BytesWriteRequest,
TensorReadRequest,
TensorWriteRequest,
ChunkStorageMetadata,
TensorStorageMetadata,
BytesStorageMetadata,
MetadataIndex,
)
def _trim(tensor: torch.Tensor) -> torch.Tensor:
tensor = tensor.detach()
if tensor.storage().size() != tensor.numel():
return tensor.clone()
return tensor
def _create_storage_key(
storage_key_to_fqn: Dict[str, str],
fqn: str
) -> str:
"""
Compute the storage key from the Fully Qualified Name
Storage keys must respect the following properties:
1) Globally unique name across all objects and ranks.
2) Suitable for usage with common storage systems (IE, alphanumeric only)
"""
storage_key = hashlib.sha256(bytes(fqn, "utf-8")).hexdigest()
counter = 0
while storage_key in storage_key_to_fqn:
storage_key = hashlib.sha256(bytes(f"{fqn}{counter}", "utf-8")).hexdigest()
counter += 1
storage_key_to_fqn[storage_key] = fqn
return storage_key
# This constant is used as the separator character between tensor name and shard name
STORAGE_KEY_SEPARATOR = "$"
def _shards_get_overlap_region_wrt_saved_tensor(
saved_shard: ShardMetadata, current_shard: ShardMetadata
@ -101,206 +48,3 @@ def _shards_get_overlap_region_wrt_saved_tensor(
)
return narrows
def _chunk_to_shard_md(chunk_md: ChunkStorageMetadata) -> ShardMetadata:
return ShardMetadata(
shard_offsets=list(chunk_md.offsets),
shard_sizes=list(chunk_md.sizes)
)
def _shard_md_to_chunk(chunk_md: ShardMetadata) -> ChunkStorageMetadata:
return ChunkStorageMetadata(
offsets=torch.Size(chunk_md.shard_offsets),
sizes=torch.Size(chunk_md.shard_sizes),
)
def _compute_sharded_tensor_md(
tensor: ShardedTensor,
) -> TensorStorageMetadata:
smd = [_shard_md_to_chunk(sm) for sm in tensor.metadata().shards_metadata]
return TensorStorageMetadata(
properties=tensor.metadata().tensor_properties,
size=torch.Size(tensor.metadata().size),
chunks=smd,
)
def _get_shard_key(shard: ShardMetadata) -> str:
"""
Compute an unique key for a shard.
This key is unique vis-a-vis other shard of the owning ShardedTensor
"""
return "_".join(str(i) for i in shard.shard_offsets)
def _get_shard_storage_key(
tensor_storage_key: str,
shard: ShardMetadata,
storage_key_to_fqn: Dict[str, str]
) -> str:
shard_key = f"{tensor_storage_key}{STORAGE_KEY_SEPARATOR}{_get_shard_key(shard)}"
return _create_storage_key(storage_key_to_fqn, shard_key)
def _prepare_sharded_tensor_write(
fqn: str,
sharded_tensor: ShardedTensor,
storage_key: str,
storage_key_to_fqn: Dict[str, str]
) -> Tuple[List[TensorWriteRequest], TensorStorageMetadata, Dict[MetadataIndex, str]]:
"""
Prepare sharded tensor write.
Args:
fqn: The FQN of ``sharded_tensor`` in the state_dict.
sharded_tensor: The sharded tensor to persist.
storage_key: The identifier for `sharded_tensor`.
storage_key_to_fqn: dict used to produce storage keys
Returns:
A 3-element tuple with the following values:
List of ``TensorWriteRequest`` for the tensor
Metadada describing the tensor.
Dictionary describing storage information for this tensor
NB `storage_key` is used to compose the key names of the local shards.
"""
write_requests = []
shard_to_storage_key: Dict[str, str] = {}
storage_md = {}
for shard_md in sharded_tensor.metadata().shards_metadata:
shard_storage_key = _get_shard_storage_key(storage_key, shard_md, storage_key_to_fqn)
shard_to_storage_key[_get_shard_key(shard_md)] = shard_storage_key
storage_md[MetadataIndex(fqn, shard_md.shard_offsets)] = shard_storage_key
for shard in sharded_tensor.local_shards():
tensor = shard.tensor.detach()
shard_storage_key = shard_to_storage_key[_get_shard_key(shard.metadata)]
wr = TensorWriteRequest(
tensor=_trim(tensor),
storage_key=shard_storage_key,
)
write_requests.append(wr)
return write_requests, _compute_sharded_tensor_md(sharded_tensor), storage_md
def _prepare_sharded_tensor_read(
fqn: str,
storage_metadata: Dict[MetadataIndex, str],
metadata: TensorStorageMetadata,
sharded_tensor_out: ShardedTensor
) -> List[TensorReadRequest]:
"""
Prepare sharded tensor read.
Args:
fqn: The FQN of ``sharded_tensor`` in the state_dict.
storage_metadata: Dictionary describing checkpoint storage.
metadata: Metadata describing the persisted sharded tensor. Normally,
this is generated by func::`_prepare_sharded_tensor_write`.
sharded_tensor_out: The ShardedTensor being read.
Returns:
A list of class::`TensorReadRequest`. When fullfilled,
`sharded_tensor_out`'s local shards load from the persisted sharded
tensor.
"""
return _prepare_generic_tensor_read(
fqn,
metadata.chunks,
sharded_tensor_out.local_shards(),
storage_metadata)
def _prepare_generic_tensor_read(
fqn: str,
checkpoint_shards: List[ChunkStorageMetadata],
local_shards: List[Shard],
storage_metadata: Dict[MetadataIndex, str]
) -> List[TensorReadRequest]:
read_reqs = []
# this is a naive quadratic algo that can be optimized later
for shard in local_shards:
# scan all mds looking for chunks
for storage_md in checkpoint_shards:
shard_md_from_storage = _chunk_to_shard_md(storage_md)
# do they overlap?
if not _check_shard_metadata_pair_overlap(
shard.metadata, shard_md_from_storage
):
continue
storage_key = storage_metadata[MetadataIndex(fqn, storage_md.offsets)]
target_tensor = shard.tensor.detach()
offsets = []
lengths = []
for (
dim,
offset_for_saved_tensor,
offset_for_current_tensor,
length,
) in _shards_get_overlap_region_wrt_saved_tensor(
saved_shard=shard_md_from_storage, current_shard=shard.metadata
):
# Note that we do NOT want to make any tensor copy.
# all operation must be view only
target_tensor = torch.narrow(
target_tensor, dim, offset_for_current_tensor, length
)
offsets.append(offset_for_saved_tensor)
lengths.append(length)
read_reqs.append(
TensorReadRequest(
tensor=target_tensor,
storage_key=storage_key,
offsets=tuple(offsets),
lengths=tuple(lengths),
)
)
return read_reqs
def _compute_tensor_md(tensor: Tensor) -> TensorStorageMetadata:
return TensorStorageMetadata(
properties=TensorProperties.create_from_tensor(tensor),
size=tensor.size(),
chunks=[ChunkStorageMetadata(
offsets=torch.Size([0] * len(tensor.shape)),
sizes=tensor.size(),
)]
)
def _prepare_tensor_write(
tensor: Tensor, fqn: str, storage_key_to_fqn: Dict[str, str]
) -> Tuple[List[TensorWriteRequest], TensorStorageMetadata, Dict[MetadataIndex, str]]:
storage_key = _create_storage_key(storage_key_to_fqn, fqn)
storage_md = {MetadataIndex(fqn, [0] * len(tensor.shape)): storage_key}
write_reqs = [
TensorWriteRequest(
tensor=_trim(tensor),
storage_key=storage_key,
)
]
return (write_reqs, _compute_tensor_md(tensor), storage_md)
def _compute_bytes_md(bytes: io.BytesIO) -> BytesStorageMetadata:
return BytesStorageMetadata(
)
def _prepare_bytes_write(
bytes: io.BytesIO, fqn: str, storage_key_to_fqn: Dict[str, str]
) -> Tuple[List[BytesWriteRequest], BytesStorageMetadata, Dict[MetadataIndex, str]]:
storage_key = _create_storage_key(storage_key_to_fqn, fqn)
storage_md = {MetadataIndex(fqn): storage_key}
write_reqs = [
BytesWriteRequest(
bytes=bytes,
storage_key=storage_key,
)
]
return (write_reqs, _compute_bytes_md(bytes), storage_md)

View File

@ -1,96 +1,22 @@
import io
from typing import Any, Dict, List, Tuple, Optional, cast
from torch.distributed._shard.metadata import ShardMetadata
from torch.distributed._shard.sharded_tensor.shard import Shard
from typing import Any, Dict, Optional
import torch
import torch.distributed as dist
from torch import Tensor
from torch.distributed._shard.sharded_tensor import (
ShardedTensor,
)
from .metadata import (
BytesReadRequest,
BytesStorageMetadata,
TensorReadRequest,
TensorStorageMetadata,
Metadata,
MetadataIndex,
)
from .resharding import (
_prepare_generic_tensor_read,
)
from .storage import (
StorageReader,
)
from .planner import LoadPlanner
from .default_planner import DefaultLoadPlanner
from .utils import _DistWrapper
def _create_shard_metadata(size: torch.Size) -> ShardMetadata:
return ShardMetadata(
shard_offsets=[0] * len(size),
shard_sizes=list(size),
)
def _create_shard_for(tensor: Tensor) -> Shard:
return Shard(
tensor=tensor,
metadata=_create_shard_metadata(tensor.size()),
)
def _reshard_and_prepare_read_request(
state_dict: Dict[str, Any], metadata_from_storage: Metadata
) -> Tuple[List[BytesReadRequest], List[TensorReadRequest]]:
"""
Use the loaded metadata and the current state dict to map the saved tensors to current tensor
"""
tensor_read_requests = []
bytes_read_requests = []
storage_md = cast(Dict[MetadataIndex, str], metadata_from_storage.storage_data)
for fqn, obj in state_dict.items():
md = metadata_from_storage.state_dict_metadata[fqn]
if isinstance(obj, ShardedTensor):
local_shards = obj.local_shards()
elif isinstance(obj, torch.Tensor):
local_shards = [_create_shard_for(obj)]
else:
if isinstance(md, BytesStorageMetadata):
bytes_io = io.BytesIO()
brr = BytesReadRequest(
bytes=bytes_io,
storage_key=storage_md[MetadataIndex(fqn)],
fqn=fqn
)
bytes_read_requests.append(brr)
else:
raise ValueError(
f"Invalid checkpoint metadata for {fqn}, " +
f"expected BytesStorageMetadata but found {type(md)}"
)
continue
if isinstance(md, TensorStorageMetadata):
checkpoint_shards = md.chunks
else:
raise ValueError(
f"Invalid checkpoint metadata for {fqn}, " +
f"expected TensorStorageMetadata but found {type(md)}"
)
tensor_read_requests += _prepare_generic_tensor_read(fqn, checkpoint_shards, local_shards, storage_md)
return (bytes_read_requests, tensor_read_requests)
def load_state_dict(
state_dict: Dict[str, Any],
storage_reader: StorageReader,
process_group: Optional[dist.ProcessGroup] = None,
coordinator_rank: int = 0,
no_dist: bool = False
no_dist: bool = False,
planner: LoadPlanner = None
) -> None:
"""
Load a distributed state_dict in SPMD style.
@ -150,25 +76,34 @@ def load_state_dict(
has an individual GPU, via ``torch.cuda.set_device()``
"""
distW = _DistWrapper(process_group, not no_dist, coordinator_rank)
if planner is None:
planner = DefaultLoadPlanner()
def load_model():
def local_step():
assert planner is not None
metadata = storage_reader.read_metadata()
bytes_read_requests, tensor_read_requests = _reshard_and_prepare_read_request(
state_dict=state_dict, metadata_from_storage=metadata
)
bytes_futures = storage_reader.read_bytes(bytes_read_requests)
tensor_futures = storage_reader.read_tensors(tensor_read_requests)
planner.init(state_dict, metadata, distW.is_coordinator)
storage_reader.init(metadata, distW.is_coordinator)
bytes_futures.wait()
local_plan = planner.create_local_plan()
local_plan = storage_reader.prepare_local_plan(local_plan)
return local_plan
# Addtional steps are required to convert the bytes to its original type
# Note that this is NOT inplace,
# it creating a new object and replace what's in the state dict
for req in bytes_read_requests:
# Ensure the BytesIO is rewound
req.bytes.seek(0)
state_dict[req.fqn] = torch.load(req.bytes)
def global_step(all_local_plans):
assert planner is not None
all_local_plans = planner.create_global_plan(all_local_plans)
all_local_plans = storage_reader.prepare_global_plan(all_local_plans)
return all_local_plans
tensor_futures.wait()
central_plan = distW.reduce_scatter("plan", local_step, global_step)
distW.all_gather("checkpoint read", load_model)
def read_data():
assert planner is not None
final_local_plan = planner.finish_plan(central_plan)
all_reads = storage_reader.read_data(final_local_plan, planner)
all_reads.wait()
return None
_ = distW.all_gather("read", read_data)

View File

@ -1,104 +1,29 @@
import io
from typing import Any, Dict, List, Tuple, Optional, Union
import torch
from typing import Optional
import torch.distributed as dist
from torch import Tensor
from torch.distributed._shard.sharded_tensor import (
ShardedTensor,
)
from .planner import SavePlanner
from .default_planner import DefaultSavePlanner
from .metadata import (
Metadata,
BytesWriteRequest,
TensorWriteRequest,
)
from .resharding import (
_prepare_sharded_tensor_write,
_prepare_tensor_write,
_prepare_bytes_write
)
from .storage import (
StorageWriter,
)
from .metadata import (
Metadata,
STATE_DICT_TYPE
)
from .utils import _DistWrapper
# -------------- private functions --------------
def _prepare(
state_dict: Dict[str, Any],
write_replicated_data: bool,
) -> Tuple[Metadata, List[BytesWriteRequest], List[TensorWriteRequest]]:
"""
Build the serialization plan for a given state_dict
Args:
state_dict: The instance to plan for.
Returns:
A tuple with the following values:
metadata: Metadata
The storage metadata describing Tensor and ShardedTensors
instances found in `state_dict`. See `Metadata` for the schema.
size_for_storage_keys: Dict[str, int]
Key is the storage key name, value is the associated size
It can used to pre allocate the storage for parallel and non sequential writes.
bytes_write_requests: List[BytesWriteRequest]
List of ByteIO write requests that should be performed by the writer.
tensor_write_requests: List[TensorWriteRequest]
List of Tensor write requests that should be performed by the writer.
"""
metadata = Metadata(state_dict_metadata={})
tensor_write_requests: List[TensorWriteRequest] = []
bytes_write_requests: List[BytesWriteRequest] = []
storage_key_to_fqn: Dict[str, str] = {}
storage_md = {}
for fqn, obj in state_dict.items():
if isinstance(obj, ShardedTensor):
st_write_reqs, st_md, storage_data = _prepare_sharded_tensor_write(fqn, obj, fqn, storage_key_to_fqn)
tensor_write_requests += st_write_reqs
metadata.state_dict_metadata[fqn] = st_md
storage_md.update(storage_data)
elif isinstance(obj, Tensor):
write_reqs, tensor_md, storage_data = _prepare_tensor_write(obj, fqn, storage_key_to_fqn)
if write_replicated_data:
tensor_write_requests += write_reqs
metadata.state_dict_metadata[fqn] = tensor_md
storage_md.update(storage_data)
else:
bytes_io = io.BytesIO()
# This produces incomplete MD for rank > 0 since we won't populate bytes_io.
# This is ok since only rank == 0 uses this data
if write_replicated_data:
torch.save(obj, bytes_io)
byte_write_reqs, bytes_md, storage_data = _prepare_bytes_write(bytes_io, fqn, storage_key_to_fqn)
if write_replicated_data:
bytes_write_requests += byte_write_reqs
metadata.state_dict_metadata[fqn] = bytes_md
storage_md.update(storage_data)
metadata.storage_data = storage_md
return (metadata, bytes_write_requests, tensor_write_requests)
def save_state_dict(
state_dict: Dict[str, Any],
state_dict: STATE_DICT_TYPE,
storage_writer: StorageWriter,
process_group: Optional[dist.ProcessGroup] = None,
coordinator_rank: int = 0,
no_dist: bool = False
) -> None:
no_dist: bool = False,
planner: SavePlanner = None
) -> Metadata:
"""
Save a distributed model in SPMD style.
@ -149,29 +74,41 @@ def save_state_dict(
has an individual GPU, via ``torch.cuda.set_device()``
"""
distW = _DistWrapper(process_group, not no_dist, coordinator_rank)
if planner is None:
planner = DefaultSavePlanner()
assert planner is not None
distW.broadcast("prepare", storage_writer.prepare)
metadata = None
global_metatadata = None
def write_step():
nonlocal metadata
(
metadata,
bytes_write_requests,
tensor_write_requests,
) = _prepare(state_dict, distW.is_coordinator)
def local_step():
assert planner is not None
planner.init(state_dict, distW.is_coordinator)
storage_writer.init(distW.is_coordinator)
local_plan = planner.create_local_plan()
local_plan = storage_writer.prepare_local_plan(local_plan)
return local_plan
combined_writes: List[Union[TensorWriteRequest, BytesWriteRequest]] = []
combined_writes.extend(tensor_write_requests)
combined_writes.extend(bytes_write_requests)
def global_step(all_local_plans):
nonlocal global_metatadata
storage_writer.prepare_storage(combined_writes)
bytes_futures = storage_writer.write_bytes(bytes_write_requests)
tensor_futures = storage_writer.write_tensors(tensor_write_requests)
torch.futures.wait_all([bytes_futures, tensor_futures])
assert planner is not None
all_local_plans, global_metatadata = planner.create_global_plan(all_local_plans)
all_local_plans = storage_writer.prepare_global_plan(all_local_plans)
return all_local_plans
def finish_checkpoint(_):
assert metadata is not None
storage_writer.finish(metadata=metadata)
central_plan = distW.reduce_scatter("plan", local_step, global_step)
distW.all_reduce("checkpoitn write", write_step, finish_checkpoint)
def write_data():
assert planner is not None
final_local_plan = planner.finish_plan(central_plan)
all_writes = storage_writer.write_data(final_local_plan, planner)
all_writes.wait()
return all_writes.value()
def finish_checkpoint(all_results):
assert global_metatadata is not None
storage_writer.finish(metadata=global_metatadata, results=all_results)
return global_metatadata
return distW.all_reduce("write", write_data, finish_checkpoint)

View File

@ -1,188 +1,227 @@
import abc
from typing import List, Union
from dataclasses import dataclass
from typing import List, Any
from torch.futures import Future
from .metadata import (
BytesReadRequest,
BytesWriteRequest,
Metadata,
TensorReadRequest,
TensorWriteRequest,
MetadataIndex,
)
from .planner import (
LoadPlan,
SavePlan,
SavePlanner,
LoadPlanner,
)
@dataclass(frozen=True)
class WriteResult:
index: MetadataIndex
size_in_bytes: int
storage_data: Any
class StorageWriter(abc.ABC):
"""
Interface used by ``save_state_dict`` to write to storage.
A subclass should expect the following sequence of calls by ``save_state_dict``
1) (called once globally) prepare()
2) prepare_storage() with the writes that will be used with (3) and (4).
3) write_bytes
4) write_tensors.
5) Wait for (2) and (3) futures. If either fail, abort checkpoint.
6) (called once globally) finish().
There's a single process that executes methods that are called once globally.
The writes from (3) and (4) are initiated before any waiting is done.
The last call to finish() has the semantics of commiting the checkpoint.
One StorageWriter instance acts as both the coordinator and the follower
in a distributed checkpoint. As part of initialization, each instance
is told its role.
A subclass should expect the following sequence of calls.
1) (all ranks) init()
2) (all ranks) prepare_local_plan()
3) (coordinator) prepare_global_plan()
4) (all ranks) write_data()
5) (coordinator) finish()
"""
@abc.abstractmethod
def prepare(self) -> None:
"""
Initialize storage to receive the checkpoint.
This method is called once globally per checkpoint before any other method.
This is in contrast to ``prepare_storage`` which is called on each process
in parallel.
Returns:
Future to signal intialization is complete.
"""
pass
@abc.abstractmethod
def write_bytes(self, requests: List[BytesWriteRequest]) -> Future[None]:
def init(self, is_coordinator: bool) -> None:
"""
Initiate writes for all requests in `requests`.
Writing can happen asynchronously and/or concurrently. A blocking
implementation is valid.
Initialize this instance.
Args:
requests (List[BytesWriteRequest]): A list of requests to write
Returns:
A future that completes once all writes have finished.
is_coordinator (bool): Whether this instance is responsible for coordinating
the checkpoint.
"""
pass
@abc.abstractmethod
def write_tensors(self, requests: List[TensorWriteRequest]) -> Future[None]:
def prepare_local_plan(self, plan: SavePlan) -> SavePlan:
"""
Initiate writes for all requests in `requests`.
Perform storage-specific local planning.
Writing can happen asynchronously and/or concurrently. A blocking
implementation is valid.
Implementors are responsible for any device to host transfers required
to copy.
While this method can produce a completely different plan, the recommended
way is to store storage-specific data in SavePlan::storage_data.
Args:
requests (List[TensorWriteRequest]): A list of requests to write
plan (SavePlan): The local plan from the ``SavePlanner`` in use.
Returns:
A future that completes once all writes have finished.
A transformed ``SavePlan`` after storage local planning
"""
pass
@abc.abstractmethod
def finish(self, metadata: Metadata) -> None:
def prepare_global_plan(self, plans: List[SavePlan]) -> List[SavePlan]:
"""
Writes the metadata and marks the current checkpoint as sucessfull.
Perform centralized planning of storage.
This method is called once globally after all data was writen
and is used to write its metadata and commit the checkpoint.
This method is only called on the coordinator instance.
The `metadata` object includes a global view of the checkpoint
and, while writing it is optional, it must be recoverable by the
StorageReader implementation.
While this method can produce a completely different plan, the preferred
way is to store storage-specific data in SavePlan::storage_data.
The actual format/schema used for serializing `metadata` is
considered and implementation detail.
Args:
plans: A list of ``SavePlan`` instances, one for each rank.
Returns:
A list of transformed ``SavePlan`` after storage global planning
"""
pass
@abc.abstractmethod
def write_data(
self,
plan: SavePlan,
planner: SavePlanner
) -> Future[List[WriteResult]]:
"""
Write all items from ``plan`` using ``planner`` to resolve the data.
A subclass should call ``SavePlanner::resolve_data`` on each item
from the plan to get access to the underlying object to write.
Subclasses should lazily call `resolve_data` as it can allocate memory.
In the case of tensors, make the following assumptions:
- They might be on any device, including not matching the one on ``WriteItem::tensor_data``
- They might be views or not contiguous. Only the projection needs to be saved.
Args:
plan (SavePlan): The save plan to execute.
planner (SavePlanner): Planner object to be used to resolve items to data.
Returns:
A future that completes to a list of WriteResult
"""
pass
@abc.abstractmethod
def finish(self, metadata: Metadata, results: List[List[WriteResult]]) -> None:
"""
Writes the metadata and marks the current checkpoint as successful.
The actual format/schema used for serializing `metadata` is an
implementation detail. The only requirement is that it is recoverable
into the same object graph.
Args:
metadata (Metadata): metadata for the new checkpoint
results: A list of WriteResults from all ranks.
Returns:
None
"""
pass
def prepare_storage(self, storage_writes: List[Union[TensorWriteRequest, BytesWriteRequest]]) -> None:
"""
Prepare the underlying storage for upcoming writes.
This is an optional override intended for advanced scenarios where
a storage layer needs wants to do some work ahead of the writing itself.
This method is called on each process in parallel before any writes are performed.
The default implementation does nothing.
Args:
storage_writes (List[Union[TensorWriteRequest, BytesWriteRequest]]): A list of
all writes that will be submited.
Returns:
None
"""
pass
class StorageReader(abc.ABC):
"""
Interface used by ``load_state_dict`` to read from storage.
One StorageReader instance acts as both the coordinator and the follower
in a distributed checkpoint. As part of initialization, each instance
is told its role.
A subclass should expect the following sequence of calls by ``load_state_dict``:
1) read_metadata() - on all ranks
2) read_bytes
3) read_tensors
The reads from (2) and (3) are initiated before any waiting is done.
Implementors must ensure host/device synchronization as part of
completion of both read requests.
1) (all ranks) read_metadata()
2) (all ranks) init
3) (all ranks) prepare_local_plan
4) (coordinator) prepare_global_plan
5) (all ranks) read_data
"""
@abc.abstractmethod
def read_bytes(self, requests: List[BytesReadRequest]) -> Future[None]:
"""
Initiate read for all requests in `requests`.
Reading happen asynchronously and/or concurrently. A blocking
implementation is valid.
Args:
requests (List[BytesReadRequest]): A list of requests to read.
Return:
A future that completes once all read have finished.
"""
pass
@abc.abstractmethod
def read_tensors(self, requests: List[TensorReadRequest]) -> Future[None]:
"""
Initiate read for all requests in `requests`.
Reading happen asynchronously and/or concurrently. A blocking
implementation is valid.
Implementors must not assume that the original device
at write time will be the same at read time.
If an implementation uses asynchronous copies to device, it must
ensure proper synchronization W.R.T. the returned future.
Args:
requests (List[BytesReadRequest]): A list of requests to read.
Returns:
A future that completes once all read have finished.
"""
pass
@abc.abstractmethod
def read_metadata(self) -> Metadata:
"""
Reads the checkpoint metadata.
Returnss:
Returns:
The metadata object associated with the checkpoint being loaded.
"""
pass
@abc.abstractmethod
def init(self, metadata: Metadata, is_coordinator: bool) -> None:
"""
Initialize this instance.
Args:
metadata (Metadata): The metadata schema to use.
is_coordinator (bool): Whether this instance is responsible for coordinating
the checkpoint.
"""
pass
@abc.abstractmethod
def prepare_local_plan(self, plan: LoadPlan) -> LoadPlan:
"""
Perform storage-specific local planning.
While this method can produce a completely different plan, the recommended
way is to store storage-specific data in LoadPlan::storage_data.
Args:
plan (LoadPlan): The local plan from the ``LoadPlanner`` in use.
Returns:
A transformed ``LoadPlan`` after storage local planning
"""
pass
@abc.abstractmethod
def prepare_global_plan(self, plans: List[LoadPlan]) -> List[LoadPlan]:
"""
Perform centralized planning of storage loading.
This method is only called on the coordinator instance.
While this method can produce a completely different plan, the preferred
way is to store storage-specific data in LoadPlan::storage_data.
Args:
plans: A list of ``LoadPlan`` instances, one for each rank.
Returns:
A list of transformed ``LoadPlan`` after storage global planning
"""
pass
@abc.abstractmethod
def read_data(self, plan: LoadPlan, planner: LoadPlanner) -> Future[None]:
"""
Reads all items from ``plan`` using ``planner`` to resolve the data.
A subclass should call ``LoadPlanner::load_bytes`` to deserialize a BytesIO
object into the right place.
A subclass should call ``LoadPlanner::resolve_tensor`` to get access to the
tensors that it should load data into.
It's the StorageLayer's responsibility to properly schedule any cross-device copies
required.
Args:
plan (LoadPlan): The local plan to execute on
planner (LoadPlanner): The planner object to use to resolve items.
Returns:
A future that completes once all reads are finished.
"""
pass
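
To make the new StorageWriter protocol concrete, here is a minimal in-memory sketch (illustrative only, not part of this PR): it follows the init -> prepare_local_plan -> prepare_global_plan -> write_data -> finish sequence documented above, keeps payloads in a dict instead of real storage, and assumes the module paths used elsewhere in this diff.

import io
from typing import Any, Dict, List, Optional

import torch
from torch.futures import Future

from torch.distributed._shard.checkpoint.metadata import Metadata
from torch.distributed._shard.checkpoint.planner import SavePlan, SavePlanner, WriteItemType
from torch.distributed._shard.checkpoint.storage import StorageWriter, WriteResult


class InMemoryWriter(StorageWriter):
    """Toy writer: resolves every item through the planner and keeps the bytes in RAM."""

    def __init__(self) -> None:
        super().__init__()
        self.blobs: Dict[Any, bytes] = {}
        self.metadata: Optional[Metadata] = None

    def init(self, is_coordinator: bool) -> None:
        self.is_coordinator = is_coordinator

    def prepare_local_plan(self, plan: SavePlan) -> SavePlan:
        return plan  # no storage-specific hints to add

    def prepare_global_plan(self, plans: List[SavePlan]) -> List[SavePlan]:
        return plans  # coordinator-only hook; nothing to rewrite here

    def write_data(self, plan: SavePlan, planner: SavePlanner) -> Future[List[WriteResult]]:
        results = []
        for item in plan.items:
            # Lazily resolve each item to its underlying data via the planner.
            data = planner.resolve_data(item)
            if item.type == WriteItemType.BYTE_IO:
                payload = data.getvalue()
            else:
                buffer = io.BytesIO()
                torch.save(data.detach().cpu(), buffer)
                payload = buffer.getvalue()
            self.blobs[item.index] = payload
            results.append(WriteResult(index=item.index, size_in_bytes=len(payload), storage_data=None))
        fut: Future[List[WriteResult]] = Future()
        fut.set_result(results)
        return fut

    def finish(self, metadata: Metadata, results: List[List[WriteResult]]) -> None:
        # Commit point: a real writer would durably persist `metadata` here.
        self.metadata = metadata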

View File

@ -265,15 +265,22 @@ def _find_shard(tensor: ShardedTensor, index: MetadataIndex) -> Shard:
return shard
raise ValueError(f"Could not find shard at '{index.offset}' for FQN: '{index.fqn}'")
def find_tensor_shard(tensor: torch.Tensor, index: MetadataIndex) -> torch.Tensor:
if isinstance(tensor, ShardedTensor):
return _find_shard(tensor, index).tensor
if index.offset is not None:
# special case looking up a tensor by origin
if index.offset == torch.Size([0] * len(tensor.size())):
return tensor
raise ValueError(f"FQN: '{index.fqn}' is not a ShardedTensor, can't find by offset: '{index.offset}'")
return tensor
def find_state_dict_object(state_dict: STATE_DICT_TYPE, index: MetadataIndex) -> Any:
if index.fqn not in state_dict:
raise ValueError(f"Could not find FQN: '{index.fqn}'")
obj = state_dict[index.fqn]
if isinstance(obj, ShardedTensor):
return _find_shard(obj, index).tensor
if index.offset is not None:
# special case looking up a tensor by origin
if isinstance(obj, torch.Tensor) and index.offset == torch.Size([0] * len(obj.size())):
return obj
if isinstance(obj, torch.Tensor):
return find_tensor_shard(obj, index)
elif index.offset is not None:
raise ValueError(f"FQN: '{index.fqn}' is not a ShardedTensor, can't find by offset: '{index.offset}'")
return obj