Mirror of https://github.com/zebrajr/pytorch.git (synced 2025-12-07 12:21:27 +01:00)
[checkpoint] Adopt Planner interface across the board. (#83781)

Change StorageReader and StorageWriter to follow the new SavePlanner / LoadPlanner design. Add an optional planner parameter to load_state_dict and save_state_dict and implement the new protocol. This includes a small rework of the FileSystem layer to support a single file per rank and to make fsync optional, matching torch.save behavior.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/83781
Approved by: https://github.com/wanchaol, https://github.com/fduwjj
This commit is contained in: parent fbf5a3f9f4, commit f66be71d77
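For orientation before the diff: a minimal sketch of how the reworked API is intended to be called, based on the signatures changed below. The single-process no_dist=True setup mirrors the tests in this commit; treat the exact import locations and the temporary path as assumptions.

import torch
from torch.distributed._shard.checkpoint import load_state_dict, save_state_dict
from torch.distributed._shard.checkpoint.filesystem import (
    FileSystemReader,
    FileSystemWriter,
)

state_dict = {"weight": torch.rand(4, 4), "step": 10}

# New FileSystemWriter options from this PR: one file per rank, optional fsync.
writer = FileSystemWriter("/tmp/ckpt", single_file_per_rank=True, sync_files=False)
# The planner argument defaults to DefaultSavePlanner / DefaultLoadPlanner when omitted.
metadata = save_state_dict(state_dict, storage_writer=writer, no_dist=True)

load_state_dict(
    state_dict,
    storage_reader=FileSystemReader("/tmp/ckpt"),
    no_dist=True,
)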
@@ -1,8 +1,9 @@
# Owner(s): ["oncall: distributed"]

import random
import sys
from typing import Optional, List, Union, cast
from typing import Optional, List, cast
from torch.distributed._shard.checkpoint.storage import WriteResult

from torch.distributed._shard.checkpoint import (
    StorageReader,
    StorageWriter,

@@ -16,26 +17,24 @@ import torch.distributed as dist
import torch.nn
import torch.futures
from torch.futures import Future
from torch.testing._internal.common_utils import TestCase

from torch.distributed._shard.checkpoint.resharding import (
    _prepare_sharded_tensor_write,
    _create_storage_key
)

from torch.distributed._shard import sharded_tensor

from torch.distributed._shard.checkpoint.state_dict_saver import (
    _prepare,
from torch.distributed._shard.checkpoint.default_planner import (
    _create_default_local_metadata,
)

from torch.distributed._shard.checkpoint.metadata import (
    BytesStorageMetadata,
    Metadata,
    BytesReadRequest,
    BytesWriteRequest,
    MetadataIndex,
    TensorReadRequest,
    TensorWriteRequest,
    TensorStorageMetadata,
)

from torch.distributed._shard.checkpoint.planner import (
    SavePlan,
    SavePlanner,
    LoadPlan,
    LoadPlanner,
)

from torch.distributed._shard.sharded_tensor import (
@@ -89,17 +88,6 @@ class TestDistributedCheckpointing(ShardedTensorTestBase):
    def world_size(self) -> int:
        return 2

    def gen_metadata(self) -> Metadata:
        module = TestModule()
        # compute the default saved metadata (must pass include_non_replicated_tensors or we'll get incomplete MD)
        metadata, _, _ = _prepare(module.state_dict(), True)

        # _prepare only produc
        metadata = [metadata]
        dist.broadcast_object_list(metadata)

        return metadata[0]

    @with_comms(init_rpc=False)
    @skip_if_lt_x_gpu(2)
    @requires_nccl()

@@ -114,15 +102,15 @@ class TestDistributedCheckpointing(ShardedTensorTestBase):
        st = sharded_tensor.zeros(spec, 4, 4, dtype=torch.float64)
        mapping = {}

        (_, md, storage_md) = _prepare_sharded_tensor_write("fqn", st, "tensor", mapping)
        md = _create_default_local_metadata({"st": st})

        self.assertEqual(1, len(storage_md))
        self.assertEqual(1, len(mapping))
        st_md = md.state_dict_metadata["st"]
        self.assertEqual(1, len(st_md.chunks))

    @with_comms(init_rpc=False)
    @skip_if_lt_x_gpu(2)
    @requires_nccl()
    def test_storage_key_mapping(self) -> None:
    def test_default_metadata(self) -> None:
        device = f"cuda:{dist.get_rank()}"
        spec = ChunkShardingSpec(
            dim=0,

@@ -138,48 +126,23 @@ class TestDistributedCheckpointing(ShardedTensorTestBase):
            'bytes': [1, 2, 3, 4],
        }

        metadata, bytes_reqs, tensor_reqs = _prepare(state_dict, write_replicated_data=self.rank == 0)

        if self.rank == 0:
            self.assertEqual(1, len(bytes_reqs))
            self.assertEqual(2, len(tensor_reqs))

            self.assertTrue('bytes' in metadata.state_dict_metadata)
            self.assertTrue(MetadataIndex('bytes') in metadata.storage_data)

            # tensor ordering is unspecified
            if len(tensor_reqs[0].tensor.size()) == 1:
                replicated = tensor_reqs[0]
                shard = tensor_reqs[1]
            else:
                replicated = tensor_reqs[1]
                shard = tensor_reqs[0]

            self.assertTrue('replicated' in metadata.state_dict_metadata)
            storage_key = MetadataIndex('replicated', torch.Size([0]))
            self.assertTrue(storage_key in metadata.storage_data)
            self.assertTrue(metadata.storage_data[storage_key], replicated.storage_key)
        else:
            self.assertEqual(0, len(bytes_reqs))
            self.assertEqual(1, len(tensor_reqs))
            shard = tensor_reqs[0]
            local_shard = state_dict["sharded"].local_shards()[0]

            self.assertTrue('sharded' in metadata.state_dict_metadata)
            storage_key = MetadataIndex('sharded', torch.Size(local_shard.metadata.shard_offsets))
            self.assertTrue(storage_key in metadata.storage_data)
            self.assertTrue(metadata.storage_data[storage_key], shard.storage_key)


class TestStorageKeys(TestCase):
    def test_create_key_handles_collision(self):
        keys = {}
        key0 = _create_storage_key(keys, "foo")
        key1 = _create_storage_key(keys, "foo")
        self.assertNotEqual(key0, key1)

        metadata = _create_default_local_metadata(state_dict)
        self.assertTrue('bytes' in metadata.state_dict_metadata)
        self.assertIsInstance(metadata.state_dict_metadata['bytes'], BytesStorageMetadata)

        self.assertTrue('replicated' in metadata.state_dict_metadata)
        self.assertIsInstance(metadata.state_dict_metadata['replicated'], TensorStorageMetadata)
        md = metadata.state_dict_metadata['replicated']
        self.assertEqual(md.size, state_dict['replicated'].size())
        self.assertEqual(md.properties.dtype, torch.float32)
        self.assertEqual(1, len(md.chunks))

        self.assertTrue('sharded' in metadata.state_dict_metadata)
        self.assertIsInstance(metadata.state_dict_metadata['sharded'], TensorStorageMetadata)
        md = metadata.state_dict_metadata['sharded']
        self.assertEqual(md.properties.dtype, torch.float32)
        self.assertEqual(md.size, state_dict['sharded'].size())
        self.assertEqual(2, len(md.chunks))

class TestStorageBase:
    def __init__(
@@ -197,16 +160,15 @@ class TestStorageBase:
        if ranks is not None and self.rank in ranks:
            raise ValueError(f"rank fail {self.rank} for {name}")

    def _fail_rank_async(self, name):
    def _fail_rank_async(self, name, result=None):
        ranks = self._get_ranks(name)
        fut = Future()
        if ranks is not None and self.rank in ranks:
            fut.set_exception(ValueError(f"async rank fail {self.rank} for {name}"))
        else:
            fut.set_result(None)
            fut.set_result(result)
        return fut


class FaultyStorageWriter(TestStorageBase, StorageWriter):
    def __init__(
        self,

@@ -214,22 +176,28 @@ class FaultyStorageWriter(TestStorageBase, StorageWriter):
    ):
        super(FaultyStorageWriter, self).__init__(fail_conf)

    def prepare(self) -> None:
        self._fail_rank("fail_prepare")
    def init(self, is_coordinator: bool) -> None:
        self._fail_rank("fail_init")

    def write_bytes(self, requests: List[BytesWriteRequest]) -> Future[None]:
        self._fail_rank("fail_write_bytes_on_ranks")
        return self._fail_rank_async("fail_write_bytes_on_ranks_async")
    def prepare_local_plan(self, plan: SavePlan) -> SavePlan:
        self._fail_rank("fail_prepare_local_plan")
        return plan

    def write_tensors(self, requests: List[TensorWriteRequest]) -> Future[None]:
        self._fail_rank("fail_write_tensors_on_ranks")
        return self._fail_rank_async("fail_write_tensors_on_ranks_async")
    def prepare_global_plan(self, plans: List[SavePlan]) -> List[SavePlan]:
        self._fail_rank("fail_prepare_global_plan")
        return plans

    def finish(self, metadata: Metadata) -> None:
    def write_data(
        self,
        plan: SavePlan,
        planner: SavePlanner
    ) -> Future[List[WriteResult]]:
        self._fail_rank("fail_write_data")
        return self._fail_rank_async("fail_write_data_async", [])

    def finish(self, metadata: Metadata, results: List[List[WriteResult]]) -> None:
        self._fail_rank("fail_finish")

    def prepare_storage(self, storage_writes: List[Union[TensorWriteRequest, BytesWriteRequest]]) -> None:
        self._fail_rank("fail_prepare_storage")

class FaultyStorageReader(TestStorageBase, StorageReader):
    def __init__(

@@ -240,22 +208,24 @@ class FaultyStorageReader(TestStorageBase, StorageReader):
        super(FaultyStorageReader, self).__init__(fail_conf)
        self.metadata = metadata

    def read_bytes(self, requests: List[BytesReadRequest]) -> Future[None]:
        self._fail_rank("fail_read_bytes")
        bad_ranks = self._get_ranks("fail_deser_bytes")
        for r in requests:
            if bad_ranks is not None and self.rank in bad_ranks:
                # this is not "guaranteed" to fail, but hard to beat
                rand = random.Random(1237)
                r.bytes.write(rand.randbytes(32))
            else:
                torch.save([1, 2, 3], r.bytes)
    def init(self, metadata: Metadata, is_coordinator: bool) -> None:
        self._fail_rank("fail_init")

        return self._fail_rank_async("fail_read_bytes_async")
    def prepare_local_plan(self, plan: LoadPlan) -> LoadPlan:
        self._fail_rank("fail_prepare_local_plan")
        return plan

    def read_tensors(self, requests: List[TensorReadRequest]) -> Future[None]:
        self._fail_rank("fail_read_tensors")
        return self._fail_rank_async("fail_read_tensors_async")
    def prepare_global_plan(self, plans: List[LoadPlan]) -> List[LoadPlan]:
        self._fail_rank("fail_prepare_global_plan")
        return plans

    def read_data(
        self,
        plan: LoadPlan,
        planner: LoadPlanner
    ) -> Future[None]:
        self._fail_rank("fail_read_data")
        return self._fail_rank_async("fail_read_data_async")

    def read_metadata(self) -> Metadata:
        self._fail_rank("fail_read_metadata")
@@ -282,6 +252,18 @@ class TestDistributedFailure(ShardedTensorTestBase):

        save_state_dict(state_dict, FaultyStorageWriter({}))

    @with_comms(init_rpc=False)
    @skip_if_lt_x_gpu(2)
    @requires_nccl()
    def test_dummy_reader_works(self) -> None:
        state_dict = {
            'sharded': sharded_tensor.rand(self.get_spec(), 20, 20),
            'replicated': torch.rand(10, 10),
            'bytes': [1, 2, 3, 4]
        }
        metadata = _create_default_local_metadata(state_dict)

        load_state_dict(state_dict, FaultyStorageReader(metadata, {}))

    def _test_dist_failure(self, callback, kwargs):
        bad_ranks = list(kwargs.values())[0] if len(kwargs) > 0 else []

@@ -318,10 +300,9 @@ class TestDistributedFailure(ShardedTensorTestBase):

    def _test_load(self, state_dict, coordinator=0, **kwargs):
        no_dist = not dist.is_initialized()
        write_replicated = dist.is_initialized() and dist.get_rank() == coordinator

        def _load():
            metadata, _, _ = _prepare(state_dict, write_replicated)
            metadata = _create_default_local_metadata(state_dict)
            load_state_dict(
                state_dict,
                storage_reader=FaultyStorageReader(metadata, kwargs),

@@ -341,21 +322,17 @@ class TestDistributedFailure(ShardedTensorTestBase):
            'bytes': [1, 2, 3, 4]
        }

        self._test_save(state_dict, fail_prepare=[0])
        self._test_save(state_dict, fail_init=[0])
        self._test_save(state_dict, fail_finish=[0])
        self._test_save(state_dict, fail_prepare_global_plan=[0])

        self._test_save(state_dict, fail_prepare_storage=[0])
        self._test_save(state_dict, fail_write_tensors_on_ranks=[1])
        self._test_save(state_dict, fail_write_tensors_on_ranks_async=[2])
        self._test_save(state_dict, fail_write_bytes_on_ranks=[3])
        self._test_save(state_dict, fail_write_bytes_on_ranks_async=[1])
        self._test_save(state_dict, fail_prepare_local_plan=[0])
        self._test_save(state_dict, fail_write_data=[2])
        self._test_save(state_dict, fail_write_data_async=[3])

        self._test_save(state_dict, fail_write_tensors_on_ranks_async=[1, 3])

        self._test_save(state_dict, coordinator=1, fail_prepare=[1])
        self._test_save(state_dict, coordinator=1, fail_init=[1])
        self._test_save(state_dict, coordinator=1, fail_finish=[1])


    def test_save_error_handling_no_dist(self) -> None:
        state_dict = {
            'replicated': torch.rand(10, 10),

@@ -364,14 +341,13 @@ class TestDistributedFailure(ShardedTensorTestBase):

        self.assertFalse(dist.is_initialized())

        self._test_save(state_dict, fail_prepare=[0])
        self._test_save(state_dict, fail_init=[0])
        self._test_save(state_dict, fail_finish=[0])
        self._test_save(state_dict, fail_prepare_global_plan=[0])

        self._test_save(state_dict, fail_prepare_storage=[0])
        self._test_save(state_dict, fail_write_tensors_on_ranks=[0])
        self._test_save(state_dict, fail_write_tensors_on_ranks_async=[0])
        self._test_save(state_dict, fail_write_bytes_on_ranks=[0])
        self._test_save(state_dict, fail_write_bytes_on_ranks_async=[0])
        self._test_save(state_dict, fail_prepare_local_plan=[0])
        self._test_save(state_dict, fail_write_data=[0])
        self._test_save(state_dict, fail_write_data_async=[0])

    @with_comms(init_rpc=False)
    @skip_if_lt_x_gpu(4)

@@ -384,17 +360,18 @@ class TestDistributedFailure(ShardedTensorTestBase):
        }

        self._test_load(state_dict)
        self._test_load(state_dict, fail_init=[0])
        self._test_load(state_dict, fail_prepare_global_plan=[0])
        self._test_load(state_dict, fail_read_metadata=[0])
        self._test_load(state_dict, fail_read_bytes=[1])
        self._test_load(state_dict, fail_read_bytes_async=[2])
        self._test_load(state_dict, fail_read_tensors=[3])
        self._test_load(state_dict, fail_read_tensors_async=[1])
        # We don't want to depend on the actual exception raised by pickle
        self._test_load(state_dict, fail_deser_bytes=[2], ignore_exception_type=True)
        self._test_load(state_dict, fail_prepare_local_plan=[1])
        self._test_load(state_dict, fail_read_data=[3])
        self._test_load(state_dict, fail_read_data_async=[1])

        self._test_load(state_dict, coordinator=3, fail_init=[0])
        self._test_load(state_dict, coordinator=1, fail_read_metadata=[3])
        self._test_load(state_dict, coordinator=2, fail_read_bytes=[0])
        self._test_load(state_dict, coordinator=3, fail_read_tensors_async=[2])
        self._test_load(state_dict, coordinator=2, fail_read_data=[0])
        self._test_load(state_dict, coordinator=3, fail_read_data_async=[2])
        self._test_load(state_dict, coordinator=1, fail_prepare_global_plan=[1])


    def test_load_error_handling_no_dist(self) -> None:

@@ -403,11 +380,12 @@ class TestDistributedFailure(ShardedTensorTestBase):
            'bytes': [1, 2, 3, 4]
        }
        self._test_load(state_dict)
        self._test_load(state_dict, fail_init=[0])
        self._test_load(state_dict, fail_read_metadata=[0])
        self._test_load(state_dict, fail_read_bytes=[0])
        self._test_load(state_dict, fail_read_bytes_async=[0])
        self._test_load(state_dict, fail_read_tensors=[0])
        self._test_load(state_dict, fail_read_tensors_async=[0])
        self._test_load(state_dict, fail_deser_bytes=[0], ignore_exception_type=True)
        self._test_load(state_dict, fail_prepare_local_plan=[0])
        self._test_load(state_dict, fail_prepare_global_plan=[0])
        self._test_load(state_dict, fail_read_data=[0])
        self._test_load(state_dict, fail_read_data_async=[0])

if __name__ == "__main__":
    run_tests()
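A condensed sketch of the fault-injection pattern these tests rely on; the helpers _test_save/_test_dist_failure wrap this, the names come from the tests above, and the exact exception type they assert on is elided here.

# The fail_conf dict maps a hook name to the ranks that should fail inside
# that hook; FaultyStorageWriter._fail_rank consults it on every call.
state_dict = {"replicated": torch.rand(10, 10), "bytes": [1, 2, 3, 4]}

writer = FaultyStorageWriter({"fail_write_data": [0]})
try:
    save_state_dict(state_dict, storage_writer=writer, no_dist=True)
except Exception:
    pass  # the tests assert that rank 0's failure inside write_data surfaces here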
@@ -1,6 +1,5 @@
# Owner(s): ["oncall: distributed"]

import sys
import os
import shutil
import tempfile

@@ -119,6 +118,23 @@ class TestDistributedStateDictSaveLoad(TestCase):

            assert_state_dict_equal(self, state_dict_to_load_to, state_dict_to_save)

        with tempfile.TemporaryDirectory() as path:
            state_dict_to_save = MyTestModule().state_dict()

            fs_writer = FileSystemWriter(path=path, single_file_per_rank=True)
            save_state_dict(state_dict=state_dict_to_save, storage_writer=fs_writer, no_dist=True)

            state_dict_to_load_to = MyTestModule().state_dict()

            with self.assertRaises(AssertionError):
                assert_state_dict_equal(self, state_dict_to_load_to, state_dict_to_save)

            # Load from file without any resharding
            fs_reader = FileSystemReader(path=path)
            load_state_dict(state_dict=state_dict_to_load_to, storage_reader=fs_reader, no_dist=True)

            assert_state_dict_equal(self, state_dict_to_load_to, state_dict_to_save)


class TestDistributedStateDictSaveLoadWithSharedTensor(ShardedTensorTestBase):
    @property
@@ -116,6 +116,7 @@ class TestSavePlan(TestCase):

        all_plans = [create_data(0), create_data(1), create_data(2), create_data(3)]
        final_plans, metadata = create_default_global_save_plan(all_plans=all_plans)

        # The default global plan updates all indexes to include hints
        for new_plan, old_plan in zip(final_plans, all_plans):
            for new_item, old_item in zip(new_plan.items, old_plan.items):
@@ -1,12 +1,8 @@
from .metadata import (
    BytesReadRequest,
    BytesWriteRequest,
    TensorStorageMetadata,
    BytesStorageMetadata,
    ChunkStorageMetadata,
    Metadata,
    TensorReadRequest,
    TensorWriteRequest,
)
from .state_dict_loader import load_state_dict
from .state_dict_saver import save_state_dict
@@ -1,7 +1,9 @@
from dataclasses import dataclass
import os
import operator
import dataclasses
import io
import pickle
from typing import List, Optional, Union, cast
from typing import List, Union, Dict, cast

import torch
from torch import Tensor
@@ -9,16 +11,100 @@ from torch.futures import Future
from pathlib import Path

from .metadata import (
    BytesReadRequest,
    BytesWriteRequest,
    Metadata,
    TensorReadRequest,
    TensorWriteRequest,
    MetadataIndex,
)
from .storage import StorageReader, StorageWriter
from .storage import (
    StorageReader,
    StorageWriter,
    WriteResult,
)

from .planner import (
    LoadItemType,
    LoadPlanner,
    LoadPlan,
    SavePlan,
    SavePlanner,
    ReadItem,
    WriteItem,
    WriteItemType,
)

from torch.distributed._shard._utils import narrow_tensor_by_index


@dataclass
class _StorageInfo:
    """
    This is the per entry storage info
    """
    relative_path: str
    offset: int
    length: int

@dataclass
class _StoragePrefix:
    prefix: str

DEFAULT_SUFIX = ".distcp"

def _trim(tensor: torch.Tensor) -> torch.Tensor:
    tensor = tensor.detach().cpu()
    if tensor.storage().size() != tensor.numel():
        tensor = tensor.clone()
    return tensor

def _result_from_write_item(item: WriteItem, size_in_bytes, storage_data) -> WriteResult:
    return WriteResult(
        index=item.index,
        size_in_bytes=size_in_bytes,
        storage_data=storage_data)

def _write_item(stream, data, write_item, storage_key):
    offset = stream.tell()

    if write_item.type == WriteItemType.BYTE_IO:
        assert isinstance(data, io.BytesIO)
        stream.write(data.getbuffer())
    else:
        assert isinstance(data, torch.Tensor)
        assert data.device == torch.device("cpu")
        torch.save(data, stream)
    length = stream.tell() - offset

    return _result_from_write_item(
        write_item,
        length,
        _StorageInfo(storage_key, offset, length)
    )

def _write_files_from_queue(
    file_queue: List,
    planner: SavePlanner,
    use_fsync: bool,
):
    write_results = []

    for file_path, file_name, write_items in file_queue:
        tensor_w = [wi for wi in write_items if wi.type != WriteItemType.BYTE_IO]
        bytes_w = [wi for wi in write_items if wi.type == WriteItemType.BYTE_IO]

        with open(file_path, "wb") as stream:
            for write_item in bytes_w:
                data = planner.resolve_data(write_item)
                write_results.append(_write_item(stream, data, write_item, file_name))

            for write_item in tensor_w:
                tensor = _trim(cast(torch.Tensor, planner.resolve_data(write_item)))
                assert not tensor.is_cuda
                write_results.append(_write_item(stream, tensor, write_item, file_name))

            if use_fsync:
                os.fsync(stream.fileno())

    return write_results

class FileSystemWriter(StorageWriter):
    """
    Basic implementation of StorageWriter using file IO.
@@ -32,108 +118,152 @@ class FileSystemWriter(StorageWriter):
    a `.metadata` file with the serialized metadata.

    """
    def __init__(self, path: Union[str, os.PathLike]) -> None:
    def __init__(
        self,
        path: Union[str, os.PathLike],
        single_file_per_rank: bool = False,
        sync_files: bool = True,
    ) -> None:
        """
        Initialize the writer pointing to `path`

        Args:
            path: diretory where the checkpoint will be writen to.
            single_file_per_rank: Produce one file per rank instead of one file per tensor/blob. Default to True.
            sync_files: force files to be synced to permanent storage. Default to True.

        N. B. If sync_files is disabled, there's no guarantee that the checkpoint will be consistent in the case of a failure.
        """
        super().__init__()
        self.path = Path(path)
        self.single_file_per_rank = single_file_per_rank
        self.sync_files = sync_files

    def write_bytes(self, requests: List[BytesWriteRequest]) -> Future[None]:
        for req in requests:
            with (self.path / req.storage_key).open("wb") as w:
                w.write(req.bytes.getbuffer())
                os.fsync(w.fileno())
    def init(self, is_coordinator: bool) -> None:
        pass

        fut: Future[None] = Future()
        fut.set_result(None)
        return fut
    def prepare_local_plan(self, plan: SavePlan) -> SavePlan:
        # There's no storage input in the local plan
        return plan

    def write_tensors(self, requests: List[TensorWriteRequest]) -> Future[None]:
        for req in requests:
            # The following couple lines are simple implementation to get
            # things going.
            #
            # At load time, to enable resharding, we use (sub)view of the tensor.
            # Since the storage of the tensor might not be contiguous. we need to
            # preserve the original view, to calculate the correct sub view at load.
            #
            # `torch.save` saves both the view and storage, it is a good option
            # for unblocking. There are two drawbacks:
            # 1. `torch.save` is pickle based, and pickle is not known for its
            #    compatibility, we should consider replacing it with a more
            #    stable option.
            # 2. pickle is not streamable.
            with (self.path / req.storage_key).open("wb") as w:
                torch.save(req.tensor, w)
                os.fsync(w.fileno())

        fut: Future[None] = Future()
        fut.set_result(None)
        return fut

    def prepare(self) -> None:
    def prepare_global_plan(self, global_plan: List[SavePlan]) -> List[SavePlan]:
        self.path.mkdir(parents=True, exist_ok=True)

    def finish(self, metadata: Metadata) -> None:
        new_plans = [
            dataclasses.replace(plan, storage_data=_StoragePrefix(f"__{i}_")) for i, plan in enumerate(global_plan)
        ]
        return new_plans

    def write_data(
        self,
        plan: SavePlan,
        planner: SavePlanner,
    ) -> Future[List[WriteResult]]:
        storage_plan: _StoragePrefix = plan.storage_data
        file_count = 0

        def gen_file():
            nonlocal file_count
            file_name = f"{storage_plan.prefix}{file_count}{DEFAULT_SUFIX}"
            file_count += 1
            return file_name

        file_queue = []
        if self.single_file_per_rank:
            file_name = gen_file()
            file_queue.append((self.path / file_name, file_name, plan.items))
        else:
            for item in plan.items:
                file_name = gen_file()
                file_queue.append((self.path / file_name, file_name, [item]))

        results = _write_files_from_queue(
            file_queue=file_queue,
            planner=planner,
            use_fsync=self.sync_files,
        )

        fut: Future[List[WriteResult]] = Future()
        fut.set_result(results)
        return fut

    def finish(self, metadata: Metadata, results: List[List[WriteResult]]) -> None:
        storage_md = dict()
        for wr_list in results:
            storage_md.update({
                wr.index: wr.storage_data for wr in wr_list
            })
        metadata.storage_data = storage_md
        with (self.path / ".metadata.tmp").open("wb") as metadata_file:
            pickle.dump(metadata, metadata_file)
            os.fsync(metadata_file.fileno())

        (self.path / ".metadata.tmp").rename(self.path / ".metadata")


class SlicedBufferedReader(io.BufferedReader):
    # TODO override read to handle (-1) correctly
    def __init__(self, base_stream: io.RawIOBase, offset: int, len: int):
        super().__init__(base_stream)
        self.offset = offset
        self.len = len
        self.seek(0)

    def seek(self, __offset: int, __whence: int = os.SEEK_SET) -> int:
        if __whence == os.SEEK_SET:
            __offset = self.offset + __offset
        elif __whence == os.SEEK_END:
            __whence = os.SEEK_SET
            __offset = (self.offset + self.len) - __offset
        return super().seek(__offset, __whence)

    def tell(self) -> int:
        return super().tell() - self.offset
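A small usage sketch for the sliced reader above, mirroring _slice_file/read_data further down; the file name, offset, length, and the checkpoint_path variable (a pathlib.Path) are made up for illustration.

# Hypothetical entry: 100 bytes starting at offset 4096 of one per-rank file.
sinfo = _StorageInfo(relative_path="__0_0.distcp", offset=4096, length=100)

with (checkpoint_path / sinfo.relative_path).open("rb") as f:
    # SlicedBufferedReader re-bases seek()/tell() so torch.load only ever
    # sees the [offset, offset + length) window of the underlying file.
    window = SlicedBufferedReader(io.FileIO(f.fileno(), closefd=False),
                                  sinfo.offset, sinfo.length)
    tensor = torch.load(window, map_location="cpu")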
class FileSystemReader(StorageReader):
    def __init__(self, path: Union[str, os.PathLike]) -> None:
        super().__init__()
        self.path = Path(path)
        self.storage_data: Dict[MetadataIndex, _StorageInfo] = dict()

    def read_tensors(self, requests: List[TensorReadRequest]) -> Future[None]:
        """
        Very basic implementation that read from file system.
        """
        # Sort the the requests by storage key and try to reuse the loaded tensors
        requests.sort(key=operator.attrgetter("storage_key"))
    def _slice_file(self, file, sinfo: _StorageInfo):
        return SlicedBufferedReader(
            io.FileIO(file.fileno(), closefd=False),
            sinfo.offset, sinfo.length
        )

        cached_storage_key = None
        view_cached: Optional[Tensor] = None
    def read_data(
        self,
        plan: LoadPlan,
        planner: LoadPlanner
    ) -> Future[None]:
        # group requests by file
        per_file: Dict[str, List[ReadItem]] = dict()
        for read_item in plan.items:
            item_md = self.storage_data[read_item.storage_index]
            path = item_md.relative_path
            per_file.setdefault(path, []).append(read_item)

        for req in requests:
            if cached_storage_key != req.storage_key or \
                    (view_cached is not None and view_cached.device != req.tensor.device):
        for relative_path, reqs in per_file.items():
            with (self.path / relative_path).open("rb") as file:
                # TODO sort by offset and cache the reading
                for req in reqs:
                    item_md = self.storage_data[req.storage_index]
                    file_slice = self._slice_file(file, item_md)
                    if req.type == LoadItemType.BYTE_IO:
                        bytes = io.BytesIO(file_slice.read(item_md.length))
                        bytes.seek(0)
                        planner.load_bytes(req, bytes)
                    else:
                        tensor = cast(Tensor, torch.load(file_slice, map_location="cpu"))
                        tensor = narrow_tensor_by_index(tensor, req.storage_offsets, req.lengths)
                        target_tensor = planner.resolve_tensor(req).detach()

                with (self.path / req.storage_key).open("rb") as storage:
                    view_cached = cast(Tensor, torch.load(storage, map_location=req.tensor.device))
                    cached_storage_key = req.storage_key

            view_to_copy: Tensor = cast(Tensor, view_cached)
            # FileSystemWrite writes the tensor as is during save.
            # During load time, we will load the Tensor (with it orignal view)
            # narrow it along all dimemsions, and copy_ it to the
            # target tensor, which will be the same size.
            view_to_copy = narrow_tensor_by_index(view_to_copy, req.offsets, req.lengths)

            assert (
                view_to_copy.size() == req.tensor.size()
            ), f"The {req.storage_key} src/dst size does not match."

            assert (
                view_to_copy.device == req.tensor.device
            ), f"cannot load across devices {view_to_copy.device} vs {req.tensor.device}"

            req.tensor.copy_(view_to_copy)

        fut: Future = Future()
        fut.set_result(None)
        return fut

    def read_bytes(self, requests: List[BytesReadRequest]) -> Future[None]:
        for req in requests:
            with (self.path / req.storage_key).open("rb") as storage:
                req.bytes.write(storage.read())
                        assert (
                            target_tensor.size() == tensor.size()
                        ), f"req {req.storage_index} mismatch sizes {target_tensor.size()} vs {tensor.size()}"
                        target_tensor.copy_(tensor)
                        planner.commit_tensor(req, target_tensor)

        fut: Future = Future()
        fut.set_result(None)

@@ -143,3 +273,13 @@ class FileSystemReader(StorageReader):
    def read_metadata(self) -> Metadata:
        with (self.path / ".metadata").open("rb") as metadata_file:
            return pickle.load(metadata_file)

    def init(self, metadata: Metadata, is_coordinator: bool) -> None:
        self.storage_data = metadata.storage_data
        assert self.storage_data is not None

    def prepare_local_plan(self, plan: LoadPlan) -> LoadPlan:
        return plan

    def prepare_global_plan(self, global_plan: List[LoadPlan]) -> List[LoadPlan]:
        return global_plan
@@ -1,12 +1,11 @@
import io
from dataclasses import dataclass, field
from typing import Dict, List, Tuple, Union, Optional, Sequence, Any
from typing import Dict, List, Union, Optional, Sequence, Any
from torch.distributed._shard.sharded_tensor.metadata import TensorProperties

import torch
from torch.distributed._shard.sharded_tensor import (
    ShardedTensor,
)
from torch.distributed._shard.sharded_tensor.metadata import TensorProperties

@dataclass
class ChunkStorageMetadata:

@@ -37,34 +36,6 @@ class Metadata:
    planner_data: Any = None
    storage_data: Any = None

@dataclass
class BytesWriteRequest:
    bytes: io.BytesIO
    storage_key: str


@dataclass
class BytesReadRequest:
    bytes: io.BytesIO
    storage_key: str
    fqn: str


@dataclass
class TensorWriteRequest:
    tensor: torch.Tensor
    storage_key: str


@dataclass
class TensorReadRequest:
    tensor: torch.Tensor
    storage_key: str
    # offset and length w.r.t. to the storage identified by ``storage_key``
    offsets: Tuple[int, ...]
    lengths: Tuple[int, ...]


@dataclass(frozen=True)
class MetadataIndex:
    """
@@ -1,61 +1,8 @@
import hashlib
import io
from typing import List, Tuple, Dict
from typing import List, Tuple

import torch
from torch import Tensor

from torch.distributed._shard.sharded_tensor import (
    ShardedTensor,
)
from torch.distributed._shard.sharding_spec import (
    ShardMetadata,
)
from torch.distributed._shard.sharding_spec._internals import (
    _check_shard_metadata_pair_overlap,
)
from torch.distributed._shard.sharded_tensor.shard import Shard
from torch.distributed._shard.sharded_tensor.metadata import TensorProperties


from .metadata import (
    BytesWriteRequest,
    TensorReadRequest,
    TensorWriteRequest,
    ChunkStorageMetadata,
    TensorStorageMetadata,
    BytesStorageMetadata,
    MetadataIndex,
)

def _trim(tensor: torch.Tensor) -> torch.Tensor:
    tensor = tensor.detach()
    if tensor.storage().size() != tensor.numel():
        return tensor.clone()
    return tensor

def _create_storage_key(
    storage_key_to_fqn: Dict[str, str],
    fqn: str
) -> str:
    """
    Compute the storage key from the Fully Qualified Name
    Storage keys must respect the following properties:
    1) Globally unique name across all objects and ranks.
    2) Suitable for usage with common storage systems (IE, alphanumeric only)
    """

    storage_key = hashlib.sha256(bytes(fqn, "utf-8")).hexdigest()
    counter = 0
    while storage_key in storage_key_to_fqn:
        storage_key = hashlib.sha256(bytes(f"{fqn}{counter}", "utf-8")).hexdigest()
        counter += 1

    storage_key_to_fqn[storage_key] = fqn
    return storage_key

# This constant is used as the separator character between tensor name and shard name
STORAGE_KEY_SEPARATOR = "$"

def _shards_get_overlap_region_wrt_saved_tensor(
    saved_shard: ShardMetadata, current_shard: ShardMetadata
@@ -101,206 +48,3 @@ def _shards_get_overlap_region_wrt_saved_tensor(
        )

    return narrows

def _chunk_to_shard_md(chunk_md: ChunkStorageMetadata) -> ShardMetadata:
    return ShardMetadata(
        shard_offsets=list(chunk_md.offsets),
        shard_sizes=list(chunk_md.sizes)
    )

def _shard_md_to_chunk(chunk_md: ShardMetadata) -> ChunkStorageMetadata:
    return ChunkStorageMetadata(
        offsets=torch.Size(chunk_md.shard_offsets),
        sizes=torch.Size(chunk_md.shard_sizes),
    )

def _compute_sharded_tensor_md(
    tensor: ShardedTensor,
) -> TensorStorageMetadata:
    smd = [_shard_md_to_chunk(sm) for sm in tensor.metadata().shards_metadata]

    return TensorStorageMetadata(
        properties=tensor.metadata().tensor_properties,
        size=torch.Size(tensor.metadata().size),
        chunks=smd,
    )


def _get_shard_key(shard: ShardMetadata) -> str:
    """
    Compute an unique key for a shard.

    This key is unique vis-a-vis other shard of the owning ShardedTensor
    """
    return "_".join(str(i) for i in shard.shard_offsets)

def _get_shard_storage_key(
    tensor_storage_key: str,
    shard: ShardMetadata,
    storage_key_to_fqn: Dict[str, str]
) -> str:
    shard_key = f"{tensor_storage_key}{STORAGE_KEY_SEPARATOR}{_get_shard_key(shard)}"

    return _create_storage_key(storage_key_to_fqn, shard_key)


def _prepare_sharded_tensor_write(
    fqn: str,
    sharded_tensor: ShardedTensor,
    storage_key: str,
    storage_key_to_fqn: Dict[str, str]
) -> Tuple[List[TensorWriteRequest], TensorStorageMetadata, Dict[MetadataIndex, str]]:
    """
    Prepare sharded tensor write.

    Args:
        fqn: The FQN of ``sharded_tensor`` in the state_dict.
        sharded_tensor: The sharded tensor to persist.
        storage_key: The identifier for `sharded_tensor`.
        storage_key_to_fqn: dict used to produce storage keys
    Returns:
        A 3-element tuple with the following values:
        List of ``TensorWriteRequest`` for the tensor
        Metadada describing the tensor.
        Dictionary describing storage information for this tensor

    NB `storage_key` is used to compose the key names of the local shards.
    """
    write_requests = []
    shard_to_storage_key: Dict[str, str] = {}
    storage_md = {}

    for shard_md in sharded_tensor.metadata().shards_metadata:
        shard_storage_key = _get_shard_storage_key(storage_key, shard_md, storage_key_to_fqn)
        shard_to_storage_key[_get_shard_key(shard_md)] = shard_storage_key
        storage_md[MetadataIndex(fqn, shard_md.shard_offsets)] = shard_storage_key

    for shard in sharded_tensor.local_shards():
        tensor = shard.tensor.detach()
        shard_storage_key = shard_to_storage_key[_get_shard_key(shard.metadata)]

        wr = TensorWriteRequest(
            tensor=_trim(tensor),
            storage_key=shard_storage_key,
        )
        write_requests.append(wr)
    return write_requests, _compute_sharded_tensor_md(sharded_tensor), storage_md
def _prepare_sharded_tensor_read(
    fqn: str,
    storage_metadata: Dict[MetadataIndex, str],
    metadata: TensorStorageMetadata,
    sharded_tensor_out: ShardedTensor
) -> List[TensorReadRequest]:
    """
    Prepare sharded tensor read.

    Args:
        fqn: The FQN of ``sharded_tensor`` in the state_dict.
        storage_metadata: Dictionary describing checkpoint storage.
        metadata: Metadata describing the persisted sharded tensor. Normally,
                  this is generated by func::`_prepare_sharded_tensor_write`.
        sharded_tensor_out: The ShardedTensor being read.

    Returns:
        A list of class::`TensorReadRequest`. When fullfilled,
        `sharded_tensor_out`'s local shards load from the persisted sharded
        tensor.
    """
    return _prepare_generic_tensor_read(
        fqn,
        metadata.chunks,
        sharded_tensor_out.local_shards(),
        storage_metadata)

def _prepare_generic_tensor_read(
    fqn: str,
    checkpoint_shards: List[ChunkStorageMetadata],
    local_shards: List[Shard],
    storage_metadata: Dict[MetadataIndex, str]
) -> List[TensorReadRequest]:
    read_reqs = []
    # this is a naive quadratic algo that can be optimized later
    for shard in local_shards:
        # scan all mds looking for chunks
        for storage_md in checkpoint_shards:
            shard_md_from_storage = _chunk_to_shard_md(storage_md)

            # do they overlap?
            if not _check_shard_metadata_pair_overlap(
                shard.metadata, shard_md_from_storage
            ):
                continue

            storage_key = storage_metadata[MetadataIndex(fqn, storage_md.offsets)]
            target_tensor = shard.tensor.detach()
            offsets = []
            lengths = []
            for (
                dim,
                offset_for_saved_tensor,
                offset_for_current_tensor,
                length,
            ) in _shards_get_overlap_region_wrt_saved_tensor(
                saved_shard=shard_md_from_storage, current_shard=shard.metadata
            ):
                # Note that we do NOT want to make any tensor copy.
                # all operation must be view only
                target_tensor = torch.narrow(
                    target_tensor, dim, offset_for_current_tensor, length
                )
                offsets.append(offset_for_saved_tensor)
                lengths.append(length)

            read_reqs.append(
                TensorReadRequest(
                    tensor=target_tensor,
                    storage_key=storage_key,
                    offsets=tuple(offsets),
                    lengths=tuple(lengths),
                )
            )
    return read_reqs
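A worked 1-D example of the narrowing above, with made-up sizes: the checkpoint holds chunks [0, 8) and [8, 16) and this rank's shard covers [4, 12), so the overlap with the first chunk is four elements at saved-offset 4 and local-offset 0.

import torch

saved_chunk = torch.arange(8, dtype=torch.float32)  # contents of chunk [0, 8)
local_shard = torch.zeros(8)                        # this rank's shard, global [4, 12)

# _shards_get_overlap_region_wrt_saved_tensor would yield
# (dim=0, offset_for_saved_tensor=4, offset_for_current_tensor=0, length=4).
target = torch.narrow(local_shard, 0, 0, 4)         # view only, no copy
target.copy_(torch.narrow(saved_chunk, 0, 4, 4))    # local_shard[:4] is now 4, 5, 6, 7
# The second chunk [8, 16) fills local_shard[4:] the same way.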
def _compute_tensor_md(tensor: Tensor) -> TensorStorageMetadata:
    return TensorStorageMetadata(
        properties=TensorProperties.create_from_tensor(tensor),
        size=tensor.size(),
        chunks=[ChunkStorageMetadata(
            offsets=torch.Size([0] * len(tensor.shape)),
            sizes=tensor.size(),
        )]
    )

def _prepare_tensor_write(
    tensor: Tensor, fqn: str, storage_key_to_fqn: Dict[str, str]
) -> Tuple[List[TensorWriteRequest], TensorStorageMetadata, Dict[MetadataIndex, str]]:
    storage_key = _create_storage_key(storage_key_to_fqn, fqn)
    storage_md = {MetadataIndex(fqn, [0] * len(tensor.shape)): storage_key}
    write_reqs = [
        TensorWriteRequest(
            tensor=_trim(tensor),
            storage_key=storage_key,
        )
    ]
    return (write_reqs, _compute_tensor_md(tensor), storage_md)


def _compute_bytes_md(bytes: io.BytesIO) -> BytesStorageMetadata:
    return BytesStorageMetadata(
    )

def _prepare_bytes_write(
    bytes: io.BytesIO, fqn: str, storage_key_to_fqn: Dict[str, str]
) -> Tuple[List[BytesWriteRequest], BytesStorageMetadata, Dict[MetadataIndex, str]]:
    storage_key = _create_storage_key(storage_key_to_fqn, fqn)
    storage_md = {MetadataIndex(fqn): storage_key}

    write_reqs = [
        BytesWriteRequest(
            bytes=bytes,
            storage_key=storage_key,
        )
    ]
    return (write_reqs, _compute_bytes_md(bytes), storage_md)
@@ -1,96 +1,22 @@
import io
from typing import Any, Dict, List, Tuple, Optional, cast
from torch.distributed._shard.metadata import ShardMetadata
from torch.distributed._shard.sharded_tensor.shard import Shard
from typing import Any, Dict, Optional

import torch
import torch.distributed as dist
from torch import Tensor
from torch.distributed._shard.sharded_tensor import (
    ShardedTensor,
)

from .metadata import (
    BytesReadRequest,
    BytesStorageMetadata,
    TensorReadRequest,
    TensorStorageMetadata,
    Metadata,
    MetadataIndex,
)
from .resharding import (
    _prepare_generic_tensor_read,
)
from .storage import (
    StorageReader,
)
from .planner import LoadPlanner
from .default_planner import DefaultLoadPlanner

from .utils import _DistWrapper

def _create_shard_metadata(size: torch.Size) -> ShardMetadata:
    return ShardMetadata(
        shard_offsets=[0] * len(size),
        shard_sizes=list(size),
    )

def _create_shard_for(tensor: Tensor) -> Shard:
    return Shard(
        tensor=tensor,
        metadata=_create_shard_metadata(tensor.size()),
    )

def _reshard_and_prepare_read_request(
    state_dict: Dict[str, Any], metadata_from_storage: Metadata
) -> Tuple[List[BytesReadRequest], List[TensorReadRequest]]:
    """
    Use the loaded metadata and the current state dict to map the saved tensors to current tensor
    """
    tensor_read_requests = []
    bytes_read_requests = []
    storage_md = cast(Dict[MetadataIndex, str], metadata_from_storage.storage_data)
    for fqn, obj in state_dict.items():
        md = metadata_from_storage.state_dict_metadata[fqn]
        if isinstance(obj, ShardedTensor):
            local_shards = obj.local_shards()
        elif isinstance(obj, torch.Tensor):
            local_shards = [_create_shard_for(obj)]
        else:
            if isinstance(md, BytesStorageMetadata):
                bytes_io = io.BytesIO()
                brr = BytesReadRequest(
                    bytes=bytes_io,
                    storage_key=storage_md[MetadataIndex(fqn)],
                    fqn=fqn
                )
                bytes_read_requests.append(brr)
            else:
                raise ValueError(
                    f"Invalid checkpoint metadata for {fqn}, " +
                    f"expected BytesStorageMetadata but found {type(md)}"
                )
            continue

        if isinstance(md, TensorStorageMetadata):
            checkpoint_shards = md.chunks
        else:
            raise ValueError(
                f"Invalid checkpoint metadata for {fqn}, " +
                f"expected TensorStorageMetadata but found {type(md)}"
            )

        tensor_read_requests += _prepare_generic_tensor_read(fqn, checkpoint_shards, local_shards, storage_md)

    return (bytes_read_requests, tensor_read_requests)


def load_state_dict(
    state_dict: Dict[str, Any],
    storage_reader: StorageReader,
    process_group: Optional[dist.ProcessGroup] = None,
    coordinator_rank: int = 0,
    no_dist: bool = False
    no_dist: bool = False,
    planner: LoadPlanner = None
) -> None:
    """
    Load a distributed state_dict in SPMD style.
@@ -150,25 +76,34 @@
        has an individual GPU, via ``torch.cuda.set_device()``
    """
    distW = _DistWrapper(process_group, not no_dist, coordinator_rank)
    if planner is None:
        planner = DefaultLoadPlanner()

    def load_model():

    def local_step():
        assert planner is not None
        metadata = storage_reader.read_metadata()
        bytes_read_requests, tensor_read_requests = _reshard_and_prepare_read_request(
            state_dict=state_dict, metadata_from_storage=metadata
        )
        bytes_futures = storage_reader.read_bytes(bytes_read_requests)
        tensor_futures = storage_reader.read_tensors(tensor_read_requests)
        planner.init(state_dict, metadata, distW.is_coordinator)
        storage_reader.init(metadata, distW.is_coordinator)

        bytes_futures.wait()
        local_plan = planner.create_local_plan()
        local_plan = storage_reader.prepare_local_plan(local_plan)
        return local_plan

        # Addtional steps are required to convert the bytes to its original type
        # Note that this is NOT inplace,
        # it creating a new object and replace what's in the state dict
        for req in bytes_read_requests:
            # Ensure the BytesIO is rewound
            req.bytes.seek(0)
            state_dict[req.fqn] = torch.load(req.bytes)
    def global_step(all_local_plans):
        assert planner is not None
        all_local_plans = planner.create_global_plan(all_local_plans)
        all_local_plans = storage_reader.prepare_global_plan(all_local_plans)
        return all_local_plans

        tensor_futures.wait()
    central_plan = distW.reduce_scatter("plan", local_step, global_step)

    distW.all_gather("checkpoint read", load_model)
    def read_data():
        assert planner is not None
        final_local_plan = planner.finish_plan(central_plan)
        all_reads = storage_reader.read_data(final_local_plan, planner)

        all_reads.wait()
        return None

    _ = distW.all_gather("read", read_data)
@@ -1,104 +1,29 @@
import io
from typing import Any, Dict, List, Tuple, Optional, Union


import torch
from typing import Optional
import torch.distributed as dist

from torch import Tensor
from torch.distributed._shard.sharded_tensor import (
    ShardedTensor,
)
from .planner import SavePlanner
from .default_planner import DefaultSavePlanner

from .metadata import (
    Metadata,
    BytesWriteRequest,
    TensorWriteRequest,
)
from .resharding import (
    _prepare_sharded_tensor_write,
    _prepare_tensor_write,
    _prepare_bytes_write
)

from .storage import (
    StorageWriter,
)

from .metadata import (
    Metadata,
    STATE_DICT_TYPE
)
from .utils import _DistWrapper


# -------------- private functions --------------

def _prepare(
    state_dict: Dict[str, Any],
    write_replicated_data: bool,
) -> Tuple[Metadata, List[BytesWriteRequest], List[TensorWriteRequest]]:
    """
    Build the serialization plan for a given state_dict

    Args:
        state_dict: The instance to plan for.

    Returns:
        A tuple with the following values:

        metadata: Metadata
            The storage metadata describing Tensor and ShardedTensors
            instances found in `state_dict`. See `Metadata` for the schema.

        size_for_storage_keys: Dict[str, int]
            Key is the storage key name, value is the associated size
            It can used to pre allocate the storage for parallel and non sequential writes.

        bytes_write_requests: List[BytesWriteRequest]
            List of ByteIO write requests that should be performed by the writer.

        tensor_write_requests: List[TensorWriteRequest]
            List of Tensor write requests that should be performed by the writer.

    """
    metadata = Metadata(state_dict_metadata={})
    tensor_write_requests: List[TensorWriteRequest] = []
    bytes_write_requests: List[BytesWriteRequest] = []
    storage_key_to_fqn: Dict[str, str] = {}

    storage_md = {}

    for fqn, obj in state_dict.items():
        if isinstance(obj, ShardedTensor):
            st_write_reqs, st_md, storage_data = _prepare_sharded_tensor_write(fqn, obj, fqn, storage_key_to_fqn)
            tensor_write_requests += st_write_reqs
            metadata.state_dict_metadata[fqn] = st_md
            storage_md.update(storage_data)
        elif isinstance(obj, Tensor):
            write_reqs, tensor_md, storage_data = _prepare_tensor_write(obj, fqn, storage_key_to_fqn)
            if write_replicated_data:
                tensor_write_requests += write_reqs
            metadata.state_dict_metadata[fqn] = tensor_md
            storage_md.update(storage_data)
        else:
            bytes_io = io.BytesIO()
            # This produces incomplete MD for rank > 0 since we won't populate bytes_io.
            # This is ok since only rank == 0 uses this data
            if write_replicated_data:
                torch.save(obj, bytes_io)
            byte_write_reqs, bytes_md, storage_data = _prepare_bytes_write(bytes_io, fqn, storage_key_to_fqn)
            if write_replicated_data:
                bytes_write_requests += byte_write_reqs
            metadata.state_dict_metadata[fqn] = bytes_md
            storage_md.update(storage_data)

    metadata.storage_data = storage_md
    return (metadata, bytes_write_requests, tensor_write_requests)

def save_state_dict(
    state_dict: Dict[str, Any],
    state_dict: STATE_DICT_TYPE,
    storage_writer: StorageWriter,
    process_group: Optional[dist.ProcessGroup] = None,
    coordinator_rank: int = 0,
    no_dist: bool = False
) -> None:
    no_dist: bool = False,
    planner: SavePlanner = None
) -> Metadata:
    """
    Save a distributed model in SPMD style.
@@ -149,29 +74,41 @@
        has an individual GPU, via ``torch.cuda.set_device()``
    """
    distW = _DistWrapper(process_group, not no_dist, coordinator_rank)
    if planner is None:
        planner = DefaultSavePlanner()
    assert planner is not None

    distW.broadcast("prepare", storage_writer.prepare)
    metadata = None
    global_metatadata = None

    def write_step():
        nonlocal metadata
        (
            metadata,
            bytes_write_requests,
            tensor_write_requests,
        ) = _prepare(state_dict, distW.is_coordinator)
    def local_step():
        assert planner is not None
        planner.init(state_dict, distW.is_coordinator)
        storage_writer.init(distW.is_coordinator)
        local_plan = planner.create_local_plan()
        local_plan = storage_writer.prepare_local_plan(local_plan)
        return local_plan

        combined_writes: List[Union[TensorWriteRequest, BytesWriteRequest]] = []
        combined_writes.extend(tensor_write_requests)
        combined_writes.extend(bytes_write_requests)
    def global_step(all_local_plans):
        nonlocal global_metatadata

        storage_writer.prepare_storage(combined_writes)
        bytes_futures = storage_writer.write_bytes(bytes_write_requests)
        tensor_futures = storage_writer.write_tensors(tensor_write_requests)
        torch.futures.wait_all([bytes_futures, tensor_futures])
        assert planner is not None
        all_local_plans, global_metatadata = planner.create_global_plan(all_local_plans)
        all_local_plans = storage_writer.prepare_global_plan(all_local_plans)
        return all_local_plans

    def finish_checkpoint(_):
        assert metadata is not None
        storage_writer.finish(metadata=metadata)
    central_plan = distW.reduce_scatter("plan", local_step, global_step)

    distW.all_reduce("checkpoitn write", write_step, finish_checkpoint)
    def write_data():
        assert planner is not None
        final_local_plan = planner.finish_plan(central_plan)
        all_writes = storage_writer.write_data(final_local_plan, planner)

        all_writes.wait()
        return all_writes.value()

    def finish_checkpoint(all_results):
        assert global_metatadata is not None
        storage_writer.finish(metadata=global_metatadata, results=all_results)
        return global_metatadata

    return distW.all_reduce("write", write_data, finish_checkpoint)
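A short sketch of the new return value; the names are the ones defined in this diff, and the writer variable is assumed to be a FileSystemWriter as in the earlier example.

# save_state_dict now returns the committed global Metadata.
md = save_state_dict(state_dict, storage_writer=writer, no_dist=True)
print(list(md.state_dict_metadata.keys()))  # one entry per state_dict key
# With FileSystemWriter, md.storage_data maps MetadataIndex -> _StorageInfo
# (relative_path, offset, length), as assembled by finish() earlier in this diff.
print(type(md.storage_data))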
@@ -1,188 +1,227 @@
import abc
from typing import List, Union
from dataclasses import dataclass
from typing import List, Any

from torch.futures import Future

from .metadata import (
    BytesReadRequest,
    BytesWriteRequest,
    Metadata,
    TensorReadRequest,
    TensorWriteRequest,
    MetadataIndex,
)

from .planner import (
    LoadPlan,
    SavePlan,
    SavePlanner,
    LoadPlanner,
)

@dataclass(frozen=True)
class WriteResult:
    index: MetadataIndex

    size_in_bytes: int
    storage_data: Any

class StorageWriter(abc.ABC):
    """
    Interface used by ``save_state_dict`` to write to storage.

    A subclass should expect the following sequence of calls by ``save_state_dict``

    1) (called once globally) prepare()
    2) prepare_storage() with the writes that will be used with (3) and (4).
    3) write_bytes
    4) write_tensors.
    5) Wait for (2) and (3) futures. If either fail, abort checkpoint.
    6) (called once globally) finish().

    There's a single process that executes methods that are called once globally.
    The writes from (3) and (4) are initiated before any waiting is done.
    The last call to finish() has the semantics of commiting the checkpoint.
    One StorageWriter instance acts as both the coordinator and the follower
    in a distributed checkpoint. As part of initialization, each instance
    is told its role.

    A subclass should expect the following sequence of calls.

    1) (all ranks) init()
    2) (all ranks) prepare_local_plan()
    3) (coordinator) prepare_global_plan()
    4) (all ranks) write_data()
    5) (coordinator) finish()
    """
    @abc.abstractmethod
    def prepare(self) -> None:
        """
        Initialize storage to receive the checkpoint.

        This method is called once globally per checkpoint before any other method.
        This is in contrast to ``prepare_storage``, which is called on each process
        in parallel.

        Returns:
            Future to signal initialization is complete.
        """
        pass

    @abc.abstractmethod
    def write_bytes(self, requests: List[BytesWriteRequest]) -> Future[None]:
    def init(self, is_coordinator: bool) -> None:
        """
        Initiate writes for all requests in `requests`.

        Writing can happen asynchronously and/or concurrently. A blocking
        implementation is valid.
        Initialize this instance.

        Args:
            requests (List[BytesWriteRequest]): A list of requests to write
        Returns:
            A future that completes once all writes have finished.
            is_coordinator (bool): Whether this instance is responsible for coordinating
                the checkpoint.
        """
        pass

    @abc.abstractmethod
    def write_tensors(self, requests: List[TensorWriteRequest]) -> Future[None]:
    def prepare_local_plan(self, plan: SavePlan) -> SavePlan:
        """
        Initiate writes for all requests in `requests`.
        Perform storage-specific local planning.

        Writing can happen asynchronously and/or concurrently. A blocking
        implementation is valid.

        Implementors are responsible for any device-to-host transfers required
        for the copy.
        While this method can produce a completely different plan, the recommended
        way is to store storage-specific data in SavePlan::storage_data.

        Args:
            requests (List[TensorWriteRequest]): A list of requests to write
            plan (SavePlan): The local plan from the ``SavePlanner`` in use.

        Returns:
            A future that completes once all writes have finished.
            A transformed ``SavePlan`` after storage local planning
        """
        pass

    @abc.abstractmethod
    def finish(self, metadata: Metadata) -> None:
    def prepare_global_plan(self, plans: List[SavePlan]) -> List[SavePlan]:
        """
        Writes the metadata and marks the current checkpoint as successful.
        Perform centralized planning of storage.

        This method is called once globally after all data was written
        and is used to write its metadata and commit the checkpoint.
        This method is only called on the coordinator instance.

        The `metadata` object includes a global view of the checkpoint
        and, while writing it is optional, it must be recoverable by the
        StorageReader implementation.
        While this method can produce a completely different plan, the preferred
        way is to store storage-specific data in SavePlan::storage_data.

        The actual format/schema used for serializing `metadata` is
        considered an implementation detail.
        Args:
            plans: A list of ``SavePlan`` instances, one for each rank.

        Returns:
            A list of transformed ``SavePlan`` after storage global planning
        """
        pass
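
    # A minimal sketch of the ``storage_data`` hook recommended above, assuming
    # SavePlan is a dataclass (so ``dataclasses.replace`` applies) and that a
    # per-rank file prefix is all the storage layer needs to carry through planning.
    # The method name is illustrative only.
    def _example_prepare_global_plan(self, plans: List[SavePlan]) -> List[SavePlan]:
        import dataclasses
        return [
            dataclasses.replace(plan, storage_data=f"rank{i}_")
            for i, plan in enumerate(plans)
        ]
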
    @abc.abstractmethod
    def write_data(
        self,
        plan: SavePlan,
        planner: SavePlanner
    ) -> Future[List[WriteResult]]:
        """
        Write all items from ``plan`` using ``planner`` to resolve the data.

        A subclass should call ``SavePlanner::resolve_data`` on each item
        from the plan to get access to the underlying object to write.

        Subclasses should lazily call `resolve_data` as it can allocate memory.
        In case of tensors, make the following assumptions:

        - They might be on any device, including a device not matching the one in ``WriteItem::tensor_data``
        - They might be views or not contiguous. Only the projection needs to be saved.

        Args:
            plan (SavePlan): The save plan to execute.
            planner (SavePlanner): Planner object to be used to resolve items to data.

        Returns:
            A future that completes to a list of WriteResult
        """
        pass

    @abc.abstractmethod
    def finish(self, metadata: Metadata, results: List[List[WriteResult]]) -> None:
        """
        Writes the metadata and marks the current checkpoint as successful.

        The actual format/schema used for serializing `metadata` is an
        implementation detail. The only requirement is that it's recoverable
        into the same object graph.

        Args:
            metadata (Metadata): metadata for the new checkpoint
            results: A list of WriteResults from all ranks.

        Returns:
            None
        """
        pass

    def prepare_storage(self, storage_writes: List[Union[TensorWriteRequest, BytesWriteRequest]]) -> None:
        """
        Prepare the underlying storage for upcoming writes.

        This is an optional override intended for advanced scenarios where
        a storage layer wants to do some work ahead of the writing itself.

        This method is called on each process in parallel before any writes are performed.

        The default implementation does nothing.

        Args:
            storage_writes (List[Union[TensorWriteRequest, BytesWriteRequest]]): A list of
                all writes that will be submitted.

        Returns:
            None
        """
        pass

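A minimal in-memory ``StorageWriter`` sketch to make the new call sequence concrete, assuming ``SavePlan`` exposes an ``items`` list whose entries carry a ``MetadataIndex`` in ``item.index`` and that ``SavePlanner.resolve_data`` returns either a tensor or a BytesIO-like object per item:

from typing import List

import torch
from torch.futures import Future

class InMemoryWriter(StorageWriter):
    """Keeps every write in a dict; useful only to illustrate the protocol."""

    def __init__(self) -> None:
        self.blobs = {}
        self.metadata = None

    def init(self, is_coordinator: bool) -> None:
        self.is_coordinator = is_coordinator

    def prepare_local_plan(self, plan: SavePlan) -> SavePlan:
        # No storage-specific rewrite; a real writer could stash per-rank
        # paths in ``plan.storage_data`` here.
        return plan

    def prepare_global_plan(self, plans: List[SavePlan]) -> List[SavePlan]:
        return plans

    def write_data(self, plan: SavePlan, planner: SavePlanner) -> Future[List[WriteResult]]:
        results = []
        for item in plan.items:  # ``items`` is an assumed field name
            data = planner.resolve_data(item)
            if isinstance(data, torch.Tensor):
                data = data.detach().cpu()  # tensors may live on any device
                size = data.numel() * data.element_size()
            else:
                size = len(data.getbuffer())  # assumed BytesIO-like object
            self.blobs[item.index] = data
            results.append(WriteResult(index=item.index, size_in_bytes=size, storage_data=None))
        fut: Future[List[WriteResult]] = Future()
        fut.set_result(results)
        return fut

    def finish(self, metadata: Metadata, results: List[List[WriteResult]]) -> None:
        # Committing here just means keeping the metadata next to the written blobs.
        self.metadata = metadata

A real writer would additionally persist the metadata durably in finish(), since that call is what commits the checkpoint.
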
class StorageReader(abc.ABC):
    """
    Interface used by ``load_state_dict`` to read from storage.

    One StorageReader instance acts as both the coordinator and the follower
    in a distributed checkpoint. As part of initialization, each instance
    is told its role.

    A subclass should expect the following sequence of calls by ``load_state_dict``:

    1) read_metadata() - on all ranks
    2) read_bytes
    3) read_tensors

    The reads from (2) and (3) are initiated before any waiting is done.

    Implementors must ensure host/device synchronization as part of
    completion of both read requests.
    1) (all ranks) read_metadata()
    2) (all ranks) init
    3) (all ranks) prepare_local_plan
    4) (coordinator) prepare_global_plan
    5) (all ranks) read_data
    """

    @abc.abstractmethod
    def read_bytes(self, requests: List[BytesReadRequest]) -> Future[None]:
        """
        Initiate read for all requests in `requests`.

        Reading can happen asynchronously and/or concurrently. A blocking
        implementation is valid.

        Args:
            requests (List[BytesReadRequest]): A list of requests to read.

        Returns:
            A future that completes once all reads have finished.
        """
        pass

    @abc.abstractmethod
    def read_tensors(self, requests: List[TensorReadRequest]) -> Future[None]:
        """
        Initiate read for all requests in `requests`.

        Reading can happen asynchronously and/or concurrently. A blocking
        implementation is valid.

        Implementors must not assume that the original device
        at write time will be the same at read time.

        If an implementation uses asynchronous copies to device, it must
        ensure proper synchronization w.r.t. the returned future.

        Args:
            requests (List[TensorReadRequest]): A list of requests to read.

        Returns:
            A future that completes once all reads have finished.
        """
        pass

    @abc.abstractmethod
    def read_metadata(self) -> Metadata:
        """
        Reads the checkpoint metadata.

        Returnss:
        Returns:
            The metadata object associated with the checkpoint being loaded.

        """
        pass

    @abc.abstractmethod
    def init(self, metadata: Metadata, is_coordinator: bool) -> None:
        """
        Initialize this instance.

        Args:
            metadata (Metadata): The metadata schema to use.
            is_coordinator (bool): Whether this instance is responsible for coordinating
                the checkpoint.
        """
        pass

    @abc.abstractmethod
    def prepare_local_plan(self, plan: LoadPlan) -> LoadPlan:
        """
        Perform storage-specific local planning.

        While this method can produce a completely different plan, the recommended
        way is to store storage-specific data in LoadPlan::storage_data.

        Args:
            plan (LoadPlan): The local plan from the ``LoadPlanner`` in use.

        Returns:
            A transformed ``LoadPlan`` after storage local planning
        """
        pass

    @abc.abstractmethod
    def prepare_global_plan(self, plans: List[LoadPlan]) -> List[LoadPlan]:
        """
        Perform centralized planning of storage loading.

        This method is only called on the coordinator instance.

        While this method can produce a completely different plan, the preferred
        way is to store storage-specific data in LoadPlan::storage_data.

        Args:
            plans: A list of ``LoadPlan`` instances, one for each rank.

        Returns:
            A list of transformed ``LoadPlan`` after storage global planning
        """
        pass

    @abc.abstractmethod
    def read_data(self, plan: LoadPlan, planner: LoadPlanner) -> Future[None]:
        """
        Reads all items from ``plan`` using ``planner`` to resolve the data.

        A subclass should call ``LoadPlanner::load_bytes`` to deserialize a BytesIO
        object into the right place.

        A subclass should call ``LoadPlanner::resolve_tensor`` to get access to the
        tensors that it should load data into.

        It's the StorageLayer's responsibility to properly schedule any cross-device
        copies required.

        Args:
            plan (LoadPlan): The local plan to execute on
            planner (LoadPlanner): The planner object to use to resolve items.

        Returns:
            A future that completes once all reads are finished.
        """
        pass

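For symmetry with the save path, a minimal load sketch, assuming ``FileSystemReader`` implements the ``StorageReader`` protocol above and that ``load_state_dict`` restores values into the passed-in ``state_dict`` in place; ``model`` stands in for an existing module:

import torch.distributed._shard.checkpoint as dist_cp

state_dict = model.state_dict()
dist_cp.load_state_dict(
    state_dict=state_dict,
    storage_reader=dist_cp.FileSystemReader("/tmp/checkpoint"),
)
model.load_state_dict(state_dict)
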
@ -265,15 +265,22 @@ def _find_shard(tensor: ShardedTensor, index: MetadataIndex) -> Shard:
            return shard
    raise ValueError(f"Could not find shard at '{index.offset}' for FQN: '{index.fqn}'")

def find_tensor_shard(tensor: torch.Tensor, index: MetadataIndex) -> torch.Tensor:
    if isinstance(tensor, ShardedTensor):
        return _find_shard(tensor, index).tensor
    if index.offset is not None:
        # special case looking up a tensor by origin
        if index.offset == torch.Size([0] * len(tensor.size())):
            return tensor
        raise ValueError(f"FQN: '{index.fqn}' is not a ShardedTensor, can't find by offset: '{index.offset}'")
    return tensor

def find_state_dict_object(state_dict: STATE_DICT_TYPE, index: MetadataIndex) -> Any:
    if index.fqn not in state_dict:
        raise ValueError(f"Could not find FQN: '{index.fqn}'")
    obj = state_dict[index.fqn]
    if isinstance(obj, ShardedTensor):
        return _find_shard(obj, index).tensor
    if index.offset is not None:
        # special case looking up a tensor by origin
        if isinstance(obj, torch.Tensor) and index.offset == torch.Size([0] * len(obj.size())):
            return obj
    if isinstance(obj, torch.Tensor):
        return find_tensor_shard(obj, index)
    elif index.offset is not None:
        raise ValueError(f"FQN: '{index.fqn}' is not a ShardedTensor, can't find by offset: '{index.offset}'")
    return obj
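
A short sketch of how these lookups behave, assuming ``MetadataIndex`` can be built from an fqn plus an optional offset, as the code above suggests:

import torch

state_dict = {"weight": torch.zeros(4, 4)}

# A plain tensor with the zero offset hits the "origin" special case and resolves to itself.
idx = MetadataIndex(fqn="weight", offset=torch.Size([0, 0]))
assert find_state_dict_object(state_dict, idx) is state_dict["weight"]

# A non-zero offset on a plain tensor raises, since only ShardedTensors have shards to find.
bad = MetadataIndex(fqn="weight", offset=torch.Size([2, 0]))
# find_state_dict_object(state_dict, bad)  -> ValueError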