[DCP][BE] Apply ufmt to DCP and turn on lintrunner for DCP (#115302)

No logic change. Just type annotations and ufmt formatting.
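For context: ufmt composes usort (import sorting) and black (code formatting), so nearly every hunk below is either a reordered import block or a reflowed call site. A representative before/after of the import sorting, taken from the pattern repeated throughout this diff:

# Before ufmt:
from typing import Optional, List, cast
# After ufmt (usort orders names case-insensitively; black reflows long calls):
from typing import cast, List, Optional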

Differential Revision: [D51914982](https://our.internmc.facebook.com/intern/diff/D51914982/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115302
Approved by: https://github.com/XilunWu, https://github.com/wz337, https://github.com/LucasLLC
ghstack dependencies: #115523
Chien-Chin Huang 2023-12-12 16:32:07 -08:00 committed by PyTorch MergeBot
parent cc28f61fa3
commit db8d409d08
31 changed files with 374 additions and 639 deletions

View File

@ -869,7 +869,6 @@ exclude_patterns = [
'test/bottleneck_test/**', # excluded by test/run_test.py
'test/distributed/argparse_util_test.py',
'test/distributed/bin/test_script.py',
'test/distributed/checkpoint/e2e/test_pipeline.py',
'test/distributed/elastic/agent/server/test/local_elastic_agent_test.py',
'test/distributed/elastic/multiprocessing/bin/test_script.py',
'test/distributed/elastic/multiprocessing/bin/zombie_test.py',
@ -1094,19 +1093,6 @@ exclude_patterns = [
'test/distributed/algorithms/test_join.py',
'test/distributed/argparse_util_test.py',
'test/distributed/bin/test_script.py',
'test/distributed/checkpoint/test_2d_fsdp_dt_checkpoint.py',
'test/distributed/checkpoint/test_checkpoint.py',
'test/distributed/checkpoint/test_dedup_tensors.py',
'test/distributed/checkpoint/test_dtensor_checkpoint.py',
'test/distributed/checkpoint/test_file_system_checkpoint.py',
'test/distributed/checkpoint/test_file_system_checkpoint_cpu.py',
'test/distributed/checkpoint/test_fsdp_model_state.py',
'test/distributed/checkpoint/test_fsdp_optim_state.py',
'test/distributed/checkpoint/test_fsspec.py',
'test/distributed/checkpoint/test_nested_dict.py',
'test/distributed/checkpoint/test_planner.py',
'test/distributed/checkpoint/test_traverse.py',
'test/distributed/checkpoint/test_utils.py',
'test/distributed/elastic/agent/server/test/__init__.py',
'test/distributed/elastic/agent/server/test/api_test.py',
'test/distributed/elastic/agent/server/test/local_elastic_agent_test.py',
@ -2010,25 +1996,6 @@ exclude_patterns = [
'torch/distributed/autograd/__init__.py',
'torch/distributed/benchmarks/benchmark_ddp_rpc.py',
'torch/distributed/c10d_logger.py',
'torch/distributed/checkpoint/__init__.py',
'torch/distributed/checkpoint/_dedup_tensors.py',
'torch/distributed/checkpoint/_fsspec_filesystem.py',
'torch/distributed/checkpoint/_nested_dict.py',
'torch/distributed/checkpoint/_sharded_tensor_utils.py',
'torch/distributed/checkpoint/_traverse.py',
'torch/distributed/checkpoint/api.py',
'torch/distributed/checkpoint/default_planner.py',
'torch/distributed/checkpoint/examples/fsdp_checkpoint_example.py',
'torch/distributed/checkpoint/filesystem.py',
'torch/distributed/checkpoint/metadata.py',
'torch/distributed/checkpoint/optimizer.py',
'torch/distributed/checkpoint/planner.py',
'torch/distributed/checkpoint/planner_helpers.py',
'torch/distributed/checkpoint/resharding.py',
'torch/distributed/checkpoint/state_dict_loader.py',
'torch/distributed/checkpoint/state_dict_saver.py',
'torch/distributed/checkpoint/storage.py',
'torch/distributed/checkpoint/utils.py',
'torch/distributed/collective_utils.py',
'torch/distributed/constants.py',
'torch/distributed/distributed_c10d.py',
@ -2442,7 +2409,6 @@ exclude_patterns = [
'torch/testing/_internal/distributed/_shard/test_common.py',
'torch/testing/_internal/distributed/_tensor/__init__.py',
'torch/testing/_internal/distributed/_tensor/common_dtensor.py',
'torch/testing/_internal/distributed/checkpoint_utils.py',
'torch/testing/_internal/distributed/ddp_under_dist_autograd_test.py',
'torch/testing/_internal/distributed/distributed_test.py',
'torch/testing/_internal/distributed/distributed_utils.py',

View File

@ -10,7 +10,7 @@ from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_di
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import FSDPTest
from torch.testing._internal.common_utils import TEST_WITH_DEV_DBG_ASAN
from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN
from torch.testing._internal.distributed.checkpoint_utils import with_temp_dir
@ -102,3 +102,7 @@ class TestPipeline(FSDPTest):
self.assertTrue(os.path.exists(pipeline_dir))
self.save_with_pipeline(pipeline_dir)
self.load_with_fsdp(pipeline_dir)
if __name__ == "__main__":
run_tests()

View File

@ -1,29 +1,28 @@
# Owner(s): ["oncall: distributed"]
import sys
from typing import Optional, List, cast
from torch.distributed.checkpoint.storage import WriteResult
from torch.distributed.checkpoint import (
StorageReader,
StorageWriter,
CheckpointException,
load_state_dict,
save_state_dict,
)
from typing import cast, List, Optional
import torch
import torch.distributed as dist
import torch.nn
import torch.futures
from torch.futures import Future
import torch.nn
from torch.distributed._shard import sharded_tensor
from torch.distributed.checkpoint.default_planner import (
_create_default_local_metadata,
from torch.distributed._shard.sharded_tensor import ShardedTensor, state_dict_hook
from torch.distributed._shard.sharding_spec import ChunkShardingSpec
from torch.distributed.checkpoint import (
CheckpointException,
load_state_dict,
save_state_dict,
StorageReader,
StorageWriter,
)
from torch.distributed.checkpoint.default_planner import _create_default_local_metadata
from torch.distributed.checkpoint.metadata import (
BytesStorageMetadata,
Metadata,
@ -31,31 +30,21 @@ from torch.distributed.checkpoint.metadata import (
)
from torch.distributed.checkpoint.planner import (
SavePlan,
SavePlanner,
LoadPlan,
LoadPlanner,
SavePlan,
SavePlanner,
)
from torch.distributed.checkpoint.storage import WriteResult
from torch.futures import Future
from torch.testing._internal.common_distributed import requires_nccl, skip_if_lt_x_gpu
from torch.distributed._shard.sharded_tensor import (
state_dict_hook,
ShardedTensor,
)
from torch.distributed._shard.sharding_spec import ChunkShardingSpec
from torch.testing._internal.common_distributed import (
requires_nccl,
skip_if_lt_x_gpu,
)
from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN
from torch.testing._internal.distributed._shard.sharded_tensor import (
ShardedTensorTestBase,
with_comms,
)
from torch.testing._internal.common_utils import (
TEST_WITH_DEV_DBG_ASAN,
run_tests,
)
if TEST_WITH_DEV_DBG_ASAN:
print(
"Skip dev-asan as torch + multiprocessing spawn have known issues",
@ -175,9 +164,7 @@ class TestStorageBase:
ranks = self._get_ranks(name)
fut = Future()
if ranks is not None and self.rank in ranks:
fut.set_exception(
ValueError(f"async rank fail {self.rank} for {name}")
)
fut.set_exception(ValueError(f"async rank fail {self.rank} for {name}"))
else:
fut.set_result(result)
return fut
@ -204,9 +191,7 @@ class FaultyStorageWriter(TestStorageBase, StorageWriter):
self._fail_rank("fail_write_data")
return self._fail_rank_async("fail_write_data_async", [])
def finish(
self, metadata: Metadata, results: List[List[WriteResult]]
) -> None:
def finish(self, metadata: Metadata, results: List[List[WriteResult]]) -> None:
self._fail_rank("fail_finish")
@ -239,9 +224,7 @@ class TestDistributedFailure(ShardedTensorTestBase):
def get_spec(self):
return ChunkShardingSpec(
dim=0,
placements=[
f"rank:{r}/cuda:{r}" for r in range(dist.get_world_size())
],
placements=[f"rank:{r}/cuda:{r}" for r in range(dist.get_world_size())],
)
@with_comms(init_rpc=False)

View File

@ -1,12 +1,11 @@
# Owner(s): ["oncall: distributed"]
import dataclasses
import torch
from torch.distributed.checkpoint._dedup_tensors import dedup_tensors
from torch.distributed.checkpoint.planner import SavePlan, WriteItemType
from torch.distributed.checkpoint.planner_helpers import (
_create_write_item_for_tensor,
)
from torch.distributed.checkpoint.planner_helpers import _create_write_item_for_tensor
from torch.testing._internal.common_utils import run_tests, TestCase
@ -33,9 +32,7 @@ class TestDedupTensor(TestCase):
self.assertEqual(2, len(dedup_plans[0].items))
self.assertEqual(1, len(dedup_plans[1].items))
self.assertIn(
"tensor_0", (item.index.fqn for item in dedup_plans[0].items)
)
self.assertIn("tensor_0", (item.index.fqn for item in dedup_plans[0].items))
self.assertIn("r0", (item.index.fqn for item in dedup_plans[0].items))
self.assertIn("r1", (item.index.fqn for item in dedup_plans[1].items))

View File

@ -6,19 +6,19 @@ import torch.distributed as dist
import torch.distributed.checkpoint as dist_cp
from torch.distributed._tensor import (
DeviceMesh,
distribute_tensor,
DTensor,
Replicate,
Shard,
distribute_tensor,
zeros,
)
from torch.testing._internal.distributed.checkpoint_utils import with_temp_dir
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
DTensorTestBase,
skip_if_lt_x_gpu,
with_comms,
)
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed.checkpoint_utils import with_temp_dir
SUBMESH_TENSOR_SIZE = 6
@ -81,9 +81,7 @@ class DTensorPlanner(DTensorTestBase):
device_type=self.device_type,
mesh=range(dist.get_world_size()),
)
sharded_dt = distribute_tensor(
tensor_to_shard, mesh, placements=[Shard(0)]
)
sharded_dt = distribute_tensor(tensor_to_shard, mesh, placements=[Shard(0)])
replicated_dt = distribute_tensor(
tensor_to_replicate, mesh, placements=[Replicate()]
)
@ -179,9 +177,7 @@ class DTensorPlanner(DTensorTestBase):
storage_writer=dist_cp.FileSystemWriter(path=CHECKPOINT_DIR),
planner=dist_cp.DefaultSavePlanner(),
)
model, _, _ = self.create_dtensor_model(
local_tensor * 10, local_tensor_2 * 10
)
model, _, _ = self.create_dtensor_model(local_tensor * 10, local_tensor_2 * 10)
state_dict = model.state_dict()
"""
@ -247,9 +243,7 @@ class DTensorPlanner(DTensorTestBase):
if k == "submesh_sdt":
if self.rank % 2 == 0:
shard_size = int(SUBMESH_TENSOR_SIZE / v.device_mesh.size())
self.assertEqual(
v.to_local().size(), torch.Size([shard_size])
)
self.assertEqual(v.to_local().size(), torch.Size([shard_size]))
self.assertEqual(v.to_local(), torch.zeros([shard_size]))
else:
self.assertEqual(v.to_local().size(), torch.Size([0]))
@ -258,9 +252,7 @@ class DTensorPlanner(DTensorTestBase):
if k == "submesh_rdt":
if self.rank % 2 == 0:
shard_size = SUBMESH_TENSOR_SIZE
self.assertEqual(
v.to_local().size(), torch.Size([shard_size])
)
self.assertEqual(v.to_local().size(), torch.Size([shard_size]))
self.assertEqual(v.to_local(), torch.zeros([shard_size]))
else:
self.assertEqual(v.to_local().size(), torch.Size([0]))

View File

@ -1,42 +1,21 @@
# Owner(s): ["oncall: distributed"]
import os
import sys
import shutil
import sys
import tempfile
from typing import Dict
import torch
import torch.distributed as dist
from torch.distributed._shard import sharded_tensor
from torch.distributed._shard.sharded_tensor import (
ShardedTensor,
state_dict_hook,
)
from torch.distributed._shard.sharded_tensor import ShardedTensor, state_dict_hook
from torch.distributed._shard.sharding_spec import (
ChunkShardingSpec,
EnumerableShardingSpec,
ShardingSpec,
ShardMetadata,
)
from torch.testing._internal.common_distributed import (
requires_nccl,
skip_if_lt_x_gpu,
)
from torch.testing._internal.common_utils import TestCase
from torch.testing._internal.distributed._shard.sharded_tensor import (
ShardedTensorTestBase,
with_comms,
)
from torch.testing._internal.distributed._shard.sharded_tensor._test_st_common import (
MyShardedModel1,
)
from torch.testing._internal.common_utils import (
TEST_WITH_DEV_DBG_ASAN,
run_tests,
)
from torch.distributed.checkpoint import (
FileSystemReader,
@ -44,6 +23,20 @@ from torch.distributed.checkpoint import (
load_state_dict,
save_state_dict,
)
from torch.testing._internal.common_distributed import requires_nccl, skip_if_lt_x_gpu
from torch.testing._internal.common_utils import (
run_tests,
TEST_WITH_DEV_DBG_ASAN,
TestCase,
)
from torch.testing._internal.distributed._shard.sharded_tensor import (
ShardedTensorTestBase,
with_comms,
)
from torch.testing._internal.distributed._shard.sharded_tensor._test_st_common import (
MyShardedModel1,
)
if TEST_WITH_DEV_DBG_ASAN:
@ -122,9 +115,7 @@ class TestDistributedStateDictSaveLoad(TestCase):
state_dict_to_load_to = MyTestModule().state_dict()
with self.assertRaises(AssertionError):
assert_state_dict_equal(
self, state_dict_to_load_to, state_dict_to_save
)
assert_state_dict_equal(self, state_dict_to_load_to, state_dict_to_save)
# Load from file without any resharding
fs_reader = FileSystemReader(path=path)
@ -134,9 +125,7 @@ class TestDistributedStateDictSaveLoad(TestCase):
no_dist=True,
)
assert_state_dict_equal(
self, state_dict_to_load_to, state_dict_to_save
)
assert_state_dict_equal(self, state_dict_to_load_to, state_dict_to_save)
with tempfile.TemporaryDirectory() as path:
state_dict_to_save = MyTestModule().state_dict()
@ -151,9 +140,7 @@ class TestDistributedStateDictSaveLoad(TestCase):
state_dict_to_load_to = MyTestModule().state_dict()
with self.assertRaises(AssertionError):
assert_state_dict_equal(
self, state_dict_to_load_to, state_dict_to_save
)
assert_state_dict_equal(self, state_dict_to_load_to, state_dict_to_save)
# Load from file without any resharding
fs_reader = FileSystemReader(path=path)
@ -163,9 +150,7 @@ class TestDistributedStateDictSaveLoad(TestCase):
no_dist=True,
)
assert_state_dict_equal(
self, state_dict_to_load_to, state_dict_to_save
)
assert_state_dict_equal(self, state_dict_to_load_to, state_dict_to_save)
class TestDistributedStateDictSaveLoadWithSharedTensor(ShardedTensorTestBase):
@ -212,15 +197,11 @@ class TestDistributedStateDictSaveLoadWithSharedTensor(ShardedTensorTestBase):
dist.barrier()
with self.assertRaises(AssertionError):
assert_state_dict_equal(
self, state_dict_to_load_to, state_dict_to_save
)
assert_state_dict_equal(self, state_dict_to_load_to, state_dict_to_save)
# Test load.
fs_reader = FileSystemReader(path=path)
load_state_dict(
state_dict=state_dict_to_load_to, storage_reader=fs_reader
)
load_state_dict(state_dict=state_dict_to_load_to, storage_reader=fs_reader)
assert_state_dict_equal(self, state_dict_to_load_to, state_dict_to_save)
dist.barrier()
@ -238,9 +219,7 @@ class TestDistributedReshardOnLoad(ShardedTensorTestBase):
def load_tensor(self, tensor: ShardedTensor) -> torch.Tensor:
res = (
torch.zeros(tensor.shape, device="cuda:0")
if dist.get_rank() == 0
else None
torch.zeros(tensor.shape, device="cuda:0") if dist.get_rank() == 0 else None
)
tensor.gather(out=res)
return res
@ -335,9 +314,7 @@ class TestDistributedReshardOnLoad(ShardedTensorTestBase):
state_dict_to_save = model_to_save.state_dict()
fs_writer = FileSystemWriter(path=path)
save_state_dict(
state_dict=state_dict_to_save, storage_writer=fs_writer
)
save_state_dict(state_dict=state_dict_to_save, storage_writer=fs_writer)
dist.barrier()
@ -404,9 +381,7 @@ class TestDistributedReshardOnLoad(ShardedTensorTestBase):
fs_reader = FileSystemReader(path=path)
load_state_dict(
state_dict=state_dict_to_load_to, storage_reader=fs_reader
)
load_state_dict(state_dict=state_dict_to_load_to, storage_reader=fs_reader)
# We can't use torch.allclose since each ST has a different sharding spec
store_tensor = self.load_tensor(model_to_save.sharded_tensor)
@ -516,9 +491,7 @@ class TestDistributedReshardOnLoad(ShardedTensorTestBase):
f"save-spec {save_spec} load-spec {load_spec}",
)
self.assertTrue(
torch.allclose(
save_dict["replicated"], load_dict_replicated
),
torch.allclose(save_dict["replicated"], load_dict_replicated),
f"save-spec {save_spec} load-spec {load_spec}",
)

View File

@ -1,23 +1,23 @@
# Owner(s): ["oncall: distributed"]
import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dist_cp
from torch.distributed.checkpoint.default_planner import (
DefaultLoadPlanner,
DefaultSavePlanner,
)
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType
import torch.distributed.checkpoint as dist_cp
import torch.distributed as dist
from torch.distributed.checkpoint.default_planner import (
DefaultSavePlanner,
DefaultLoadPlanner,
)
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
DTensorTestBase,
with_comms,
)
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed.checkpoint_utils import with_temp_dir

View File

@ -9,23 +9,12 @@ import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dcp
import torch.nn as nn
from torch.distributed.checkpoint._fsspec_filesystem import (
FsspecReader,
FsspecWriter,
)
from torch.distributed.checkpoint.optimizer import (
load_sharded_optimizer_state_dict,
)
from torch.distributed.checkpoint._fsspec_filesystem import FsspecReader, FsspecWriter
from torch.distributed.checkpoint.optimizer import load_sharded_optimizer_state_dict
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType
from torch.testing._internal.common_distributed import (
requires_nccl,
skip_if_lt_x_gpu,
)
from torch.testing._internal.common_utils import (
run_tests,
TestCase,
)
from torch.testing._internal.common_distributed import requires_nccl, skip_if_lt_x_gpu
from torch.testing._internal.common_utils import run_tests, TestCase
from torch.testing._internal.distributed._shard.sharded_tensor import (
ShardedTensorTestBase,
with_comms,
@ -182,9 +171,7 @@ class TestFSSpecWithDist(ShardedTensorTestBase):
return list(iter(opt.state.values()))[idx]
# Adam lazily creates its state
self.assertEqual(
opt_at(optim, 0)["exp_avg"], opt_at(optim_2, 0)["exp_avg"]
)
self.assertEqual(opt_at(optim, 0)["exp_avg"], opt_at(optim_2, 0)["exp_avg"])
self.assertEqual(
opt_at(optim, 0)["exp_avg_sq"], opt_at(optim_2, 0)["exp_avg_sq"]
)

View File

@ -1,11 +1,11 @@
# Owner(s): ["oncall: distributed"]
import torch
from torch.testing._internal.common_utils import run_tests, TestCase
from torch.distributed.checkpoint._nested_dict import (
flatten_state_dict,
unflatten_state_dict,
)
from torch.testing._internal.common_utils import run_tests, TestCase
class TestFlattening(TestCase):

View File

@ -3,43 +3,45 @@
import sys
import torch
from torch.distributed.checkpoint.planner import LoadItemType, WriteItemType
from torch.distributed._shard.sharded_tensor import (
Shard,
ShardMetadata,
ShardedTensor,
ShardedTensorMetadata,
ShardMetadata,
)
from torch.distributed._shard.sharded_tensor.metadata import TensorProperties
from torch.distributed.checkpoint._dedup_tensors import dedup_tensors
from torch.testing._internal.common_utils import (
TestCase,
TEST_WITH_DEV_DBG_ASAN,
run_tests,
from torch.distributed.checkpoint.default_planner import (
_create_default_local_metadata,
create_default_global_save_plan,
create_default_local_load_plan,
create_default_local_save_plan,
)
from torch.distributed.checkpoint.metadata import (
BytesStorageMetadata,
ChunkStorageMetadata,
MetadataIndex,
TensorStorageMetadata,
ChunkStorageMetadata,
)
from torch.distributed.checkpoint.planner import LoadItemType, WriteItemType
from torch.distributed.checkpoint.planner_helpers import (
create_read_items_for_chunk_list,
)
from torch.testing._internal.common_utils import (
run_tests,
TEST_WITH_DEV_DBG_ASAN,
TestCase,
)
from torch.testing._internal.distributed.distributed_utils import (
with_dist,
with_fake_comms,
with_dist
)
from torch.distributed.checkpoint.default_planner import (
create_default_global_save_plan,
create_default_local_save_plan,
create_default_local_load_plan,
_create_default_local_metadata
)
from torch.distributed.checkpoint.planner_helpers import create_read_items_for_chunk_list
from torch.distributed.checkpoint._dedup_tensors import dedup_tensors
if TEST_WITH_DEV_DBG_ASAN:
print(
@ -48,30 +50,34 @@ if TEST_WITH_DEV_DBG_ASAN:
)
sys.exit(0)
def create_sharded_tensor(rank, world_size, shards_per_rank, shard_size=8):
shards_metadata = []
local_shards = []
for idx in range(0, world_size * shards_per_rank):
shard_rank = idx // shards_per_rank
shard_md = ShardMetadata(shard_offsets=[idx * shard_size], shard_sizes=[shard_size], placement=f"rank:{shard_rank}/cpu")
shard_md = ShardMetadata(
shard_offsets=[idx * shard_size],
shard_sizes=[shard_size],
placement=f"rank:{shard_rank}/cpu",
)
shards_metadata.append(shard_md)
if shard_rank == rank:
shard = Shard.from_tensor_and_offsets(
torch.rand(*shard_md.shard_sizes),
shard_offsets=shard_md.shard_offsets,
rank=rank
rank=rank,
)
local_shards.append(shard)
sharded_tensor_md = ShardedTensorMetadata(
shards_metadata=shards_metadata,
size=torch.Size([shard_size * len(shards_metadata)]),
tensor_properties=TensorProperties.create_from_tensor(torch.zeros(1))
tensor_properties=TensorProperties.create_from_tensor(torch.zeros(1)),
)
return ShardedTensor._init_from_local_shards_and_global_metadata(
local_shards=local_shards,
sharded_tensor_metadata=sharded_tensor_md
local_shards=local_shards, sharded_tensor_metadata=sharded_tensor_md
)
@ -81,18 +87,17 @@ class TestSavePlan(TestCase):
tensor = torch.rand(10)
val = [1, 2, 3]
st = create_sharded_tensor(rank=1, world_size=4, shards_per_rank=1)
state_dict = {
"tensor": tensor,
"value": val,
"st": st
}
state_dict = {"tensor": tensor, "value": val, "st": st}
plan = create_default_local_save_plan(state_dict, False)
self.assertEqual(2, len(plan.items))
wi = plan.items[0]
self.assertEqual(wi.index, MetadataIndex("tensor", [0]))
self.assertEqual(wi.type, WriteItemType.TENSOR)
self.assertEqual(wi.tensor_data.size, tensor.size())
self.assertEqual(wi.tensor_data.properties, TensorProperties.create_from_tensor(torch.zeros(1)))
self.assertEqual(
wi.tensor_data.properties,
TensorProperties.create_from_tensor(torch.zeros(1)),
)
self.assertEqual(wi.tensor_data.chunk.offsets, torch.Size([0]))
self.assertEqual(wi.tensor_data.chunk.sizes, torch.Size([10]))
@ -100,7 +105,10 @@ class TestSavePlan(TestCase):
self.assertEqual(st_wi.index, MetadataIndex("st", [8]))
self.assertEqual(st_wi.type, WriteItemType.SHARD)
self.assertEqual(st_wi.tensor_data.size, st.size())
self.assertEqual(st_wi.tensor_data.properties, TensorProperties.create_from_tensor(torch.zeros(1)))
self.assertEqual(
st_wi.tensor_data.properties,
TensorProperties.create_from_tensor(torch.zeros(1)),
)
self.assertEqual(st_wi.tensor_data.chunk.offsets, torch.Size([8]))
self.assertEqual(st_wi.tensor_data.chunk.sizes, torch.Size([8]))
@ -111,7 +119,10 @@ class TestSavePlan(TestCase):
tensor_wi = next(wi for wi in plan.items if wi.type == WriteItemType.TENSOR)
self.assertEqual(tensor_wi.index, MetadataIndex("tensor", [0]))
self.assertEqual(tensor_wi.tensor_data.size, tensor.size())
self.assertEqual(tensor_wi.tensor_data.properties, TensorProperties.create_from_tensor(tensor))
self.assertEqual(
tensor_wi.tensor_data.properties,
TensorProperties.create_from_tensor(tensor),
)
self.assertEqual(tensor_wi.tensor_data.chunk.offsets, torch.Size([0]))
self.assertEqual(tensor_wi.tensor_data.chunk.sizes, torch.Size([10]))
@ -125,11 +136,7 @@ class TestSavePlan(TestCase):
tensor = torch.rand(10)
val = [1, 2, 3]
st = create_sharded_tensor(rank=rank, world_size=4, shards_per_rank=1)
state_dict = {
"tensor": tensor,
"value": val,
"st": st
}
state_dict = {"tensor": tensor, "value": val, "st": st}
return create_default_local_save_plan(state_dict, rank == 0)
all_plans = [create_data(0), create_data(1), create_data(2), create_data(3)]
@ -150,11 +157,15 @@ class TestSavePlan(TestCase):
else:
self.assertTrue(isinstance(item_md, TensorStorageMetadata))
self.assertEqual(item_md.size, old_item.tensor_data.size)
self.assertEqual(item_md.properties, old_item.tensor_data.properties)
self.assertEqual(
item_md.properties, old_item.tensor_data.properties
)
self.assertIsNotNone(new_item.index.index)
# Make sure the hint is correct
self.assertEqual(item_md.chunks[new_item.index.index], old_item.tensor_data.chunk)
self.assertEqual(
item_md.chunks[new_item.index.index], old_item.tensor_data.chunk
)
def test_local_load_plan(self):
def create_state_dict(rank):
@ -162,11 +173,7 @@ class TestSavePlan(TestCase):
tensor = torch.rand(10)
val = [1, 2, 3]
st = create_sharded_tensor(rank=rank, world_size=4, shards_per_rank=1)
return {
"tensor": tensor,
"value": val,
"st": st
}
return {"tensor": tensor, "value": val, "st": st}
state_dict = create_state_dict(1)
metadata = _create_default_local_metadata(state_dict)
@ -175,7 +182,9 @@ class TestSavePlan(TestCase):
# This will create 3 entries
self.assertEqual(3, len(load_plan.items))
st_item = next(ri for ri in load_plan.items if ri.dest_index.fqn == "st")
tensor_item = next(ri for ri in load_plan.items if ri.dest_index.fqn == "tensor")
tensor_item = next(
ri for ri in load_plan.items if ri.dest_index.fqn == "tensor"
)
bytes_item = next(ri for ri in load_plan.items if ri.dest_index.fqn == "value")
self.assertEqual(st_item.type, LoadItemType.TENSOR)
@ -208,7 +217,6 @@ class TestSavePlan(TestCase):
)
}
# Rank 1 has a 16-byte shard covering [16, 32)
world8_state_dict = create_state_dict(rank=1, world_size=8)
world8_metadata = _create_default_local_metadata(world8_state_dict)
@ -221,8 +229,12 @@ class TestSavePlan(TestCase):
# Each 4-world shard has 32 elements, so it needs to load 2 shards
load_plan = create_default_local_load_plan(world4_state_dict, world8_metadata)
self.assertEqual(2, len(load_plan.items))
low_ri = next(ri for ri in load_plan.items if ri.dest_offsets == torch.Size([0]))
high_ri = next(ri for ri in load_plan.items if ri.dest_offsets == torch.Size([16]))
low_ri = next(
ri for ri in load_plan.items if ri.dest_offsets == torch.Size([0])
)
high_ri = next(
ri for ri in load_plan.items if ri.dest_offsets == torch.Size([16])
)
self.assertEqual(low_ri.storage_index, MetadataIndex("st", [32]))
self.assertEqual(low_ri.storage_offsets, torch.Size([0]))
@ -258,6 +270,7 @@ class TestSavePlan(TestCase):
shard_size=120 // world_size,
)
}
# rank 1 has a 30-byte shard covering [30, 60)
world4_state_dict = create_state_dict(rank=1, world_size=4)
world4_metadata = _create_default_local_metadata(world4_state_dict)
@ -268,9 +281,13 @@ class TestSavePlan(TestCase):
load_plan = create_default_local_load_plan(world3_state_dict, world4_metadata)
self.assertEqual(2, len(load_plan.items))
# this is [30, 60) to load [40, 60)
low_ri = next(ri for ri in load_plan.items if ri.dest_offsets == torch.Size([0]))
low_ri = next(
ri for ri in load_plan.items if ri.dest_offsets == torch.Size([0])
)
# this is [60, 90) to load [60, 80)
high_ri = next(ri for ri in load_plan.items if ri.dest_offsets == torch.Size([20]))
high_ri = next(
ri for ri in load_plan.items if ri.dest_offsets == torch.Size([20])
)
self.assertEqual(low_ri.storage_index, MetadataIndex("st", [30]))
self.assertEqual(low_ri.storage_offsets, torch.Size([10]))
@ -284,27 +301,19 @@ class TestSavePlan(TestCase):
self.assertEqual(high_ri.dest_offsets, torch.Size([20]))
self.assertEqual(high_ri.lengths, torch.Size([20]))
class TestPlannerHelpers(TestCase):
def test_create_read_item_from_chunks(self):
tensor_md = TensorStorageMetadata(
properties=TensorProperties.create_from_tensor(torch.empty([16])),
size=torch.Size([16]),
chunks=[
ChunkStorageMetadata(
offsets=torch.Size([0]),
sizes=torch.Size([8])
),
ChunkStorageMetadata(
offsets=torch.Size([8]),
sizes=torch.Size([8])
)
]
ChunkStorageMetadata(offsets=torch.Size([0]), sizes=torch.Size([8])),
ChunkStorageMetadata(offsets=torch.Size([8]), sizes=torch.Size([8])),
],
)
chunk = ChunkStorageMetadata(
offsets=torch.Size([4]),
sizes=torch.Size([7])
)
chunk = ChunkStorageMetadata(offsets=torch.Size([4]), sizes=torch.Size([7]))
read_items = create_read_items_for_chunk_list("foo", tensor_md, [chunk])
self.assertEqual(2, len(read_items))
@ -316,7 +325,6 @@ class TestPlannerHelpers(TestCase):
self.assertEqual(torch.Size([4]), read_items[0].lengths)
self.assertEqual(MetadataIndex("foo", [4]), read_items[1].dest_index)
self.assertEqual(torch.Size([4]), read_items[1].dest_offsets)
@ -325,5 +333,6 @@ class TestPlannerHelpers(TestCase):
self.assertEqual(torch.Size([3]), read_items[1].lengths)
if __name__ == "__main__":
run_tests()
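The interval comments in the resharding tests above follow from simple overlap arithmetic. A standalone sketch of that computation, using the same shard layout as the second resharding test (a toy calculation, not the real planner code):

saved = [(0, 30), (30, 30), (60, 30), (90, 30)]  # world_size=4 shards: (offset, length)
cur_off, cur_len = 40, 40                        # rank 1's shard under world_size=3
for off, ln in saved:
    lo, hi = max(off, cur_off), min(off + ln, cur_off + cur_len)
    if lo < hi:
        print(f"storage_offset={lo - off} dest_offset={lo - cur_off} length={hi - lo}")
# storage_offset=10 dest_offset=0 length=20   (reading from the [30, 60) shard)
# storage_offset=0 dest_offset=20 length=20   (reading from the [60, 90) shard)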

View File

@ -1,6 +1,7 @@
# Owner(s): ["oncall: distributed"]
from collections import OrderedDict
import torch
import torch.distributed.checkpoint._traverse as _traverse
@ -95,9 +96,7 @@ class TestTraverse(TestCase):
self.assertEqual(data[("key0", "key2")], torch.tensor([1]))
def test_traverse_doesnt_ignore_intermediate_collections(self) -> None:
state_dict: STATE_DICT_TYPE = {
"key0": [{"key1": {"key2": torch.tensor([1])}}]
}
state_dict: STATE_DICT_TYPE = {"key0": [{"key1": {"key2": torch.tensor([1])}}]}
data = {}

View File

@ -6,22 +6,20 @@ import torch
from torch.distributed._shard.sharded_tensor import (
Shard,
ShardMetadata,
ShardedTensor,
ShardedTensorMetadata,
ShardMetadata,
)
from torch.distributed._shard.sharded_tensor.metadata import TensorProperties
from torch.distributed.checkpoint.metadata import MetadataIndex
from torch.distributed.checkpoint.utils import find_state_dict_object
from torch.testing._internal.common_utils import (
TestCase,
TEST_WITH_DEV_DBG_ASAN,
run_tests,
TEST_WITH_DEV_DBG_ASAN,
TestCase,
)
from torch.distributed.checkpoint.utils import find_state_dict_object
from torch.distributed.checkpoint.metadata import MetadataIndex
from torch.testing._internal.distributed.distributed_utils import (
with_fake_comms
)
from torch.testing._internal.distributed.distributed_utils import with_fake_comms
if TEST_WITH_DEV_DBG_ASAN:
print(
@ -30,30 +28,32 @@ if TEST_WITH_DEV_DBG_ASAN:
)
sys.exit(0)
def create_sharded_tensor(rank, world_size, shards_per_rank):
shards_metadata = []
local_shards = []
for idx in range(0, world_size * shards_per_rank):
shard_rank = idx // shards_per_rank
shard_md = ShardMetadata(shard_offsets=[idx * 8], shard_sizes=[8], placement=f"rank:{shard_rank}/cpu")
shard_md = ShardMetadata(
shard_offsets=[idx * 8], shard_sizes=[8], placement=f"rank:{shard_rank}/cpu"
)
shards_metadata.append(shard_md)
if shard_rank == rank:
shard = Shard.from_tensor_and_offsets(
torch.rand(*shard_md.shard_sizes),
shard_offsets=shard_md.shard_offsets,
rank=rank
rank=rank,
)
local_shards.append(shard)
sharded_tensor_md = ShardedTensorMetadata(
shards_metadata=shards_metadata,
size=torch.Size([8 * len(shards_metadata)]),
tensor_properties=TensorProperties.create_from_tensor(torch.zeros(1))
tensor_properties=TensorProperties.create_from_tensor(torch.zeros(1)),
)
return ShardedTensor._init_from_local_shards_and_global_metadata(
local_shards=local_shards,
sharded_tensor_metadata=sharded_tensor_md
local_shards=local_shards, sharded_tensor_metadata=sharded_tensor_md
)

View File

@ -1,23 +1,15 @@
from .api import CheckpointException
from .checkpointer import Checkpointer
from .default_planner import DefaultLoadPlanner, DefaultSavePlanner
from .filesystem import FileSystemCheckpointer, FileSystemReader, FileSystemWriter
from .metadata import (
TensorStorageMetadata,
BytesStorageMetadata,
ChunkStorageMetadata,
Metadata,
TensorStorageMetadata,
)
from .state_dict_loader import load_state_dict, load
from .state_dict_saver import save_state_dict, save
from .storage import StorageReader, StorageWriter
from .checkpointer import Checkpointer
from .filesystem import FileSystemReader, FileSystemWriter, FileSystemCheckpointer
from .api import CheckpointException
from .planner import (
SavePlanner,
LoadPlanner,
SavePlan,
LoadPlan,
ReadItem,
WriteItem,
)
from .default_planner import DefaultSavePlanner, DefaultLoadPlanner
from .optimizer import load_sharded_optimizer_state_dict
from .planner import LoadPlan, LoadPlanner, ReadItem, SavePlan, SavePlanner, WriteItem
from .state_dict_loader import load, load_state_dict
from .state_dict_saver import save, save_state_dict
from .storage import StorageReader, StorageWriter
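A usage sketch of the package surface re-exported above (the path is illustrative; no_dist=True mirrors the single-process tests elsewhere in this PR):

import torch
import torch.distributed.checkpoint as dist_cp

state_dict = {"weight": torch.rand(4)}
dist_cp.save_state_dict(
    state_dict=state_dict,
    storage_writer=dist_cp.FileSystemWriter(path="/tmp/ckpt"),
    no_dist=True,
)
dist_cp.load_state_dict(
    state_dict=state_dict,
    storage_reader=dist_cp.FileSystemReader(path="/tmp/ckpt"),
    no_dist=True,
)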

View File

@ -23,8 +23,10 @@ def init_logger() -> logging.Logger:
logger.propagate = False
return logger
logger = init_logger()
# TODO add docstring for dedup_tensors
def dedup_tensors(all_plans: List[SavePlan]) -> List[SavePlan]:
all_plans = list(all_plans)
@ -51,8 +53,6 @@ def dedup_tensors(all_plans: List[SavePlan]) -> List[SavePlan]:
for write_item in all_plans[plan_idx].items
if write_item.index not in key_set
]
all_plans[plan_idx] = dataclasses.replace(
all_plans[plan_idx], items=new_items
)
all_plans[plan_idx] = dataclasses.replace(all_plans[plan_idx], items=new_items)
return all_plans
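Since the docstring is still a TODO above, here is a minimal sketch of the deduplication idea using toy stand-ins for SavePlan/WriteItem (the real policy for which plan keeps a replicated item may differ):

import dataclasses
from dataclasses import dataclass
from typing import List

@dataclass
class WriteItem:  # toy stand-in for torch.distributed.checkpoint.planner.WriteItem
    index: str

@dataclass
class SavePlan:   # toy stand-in for torch.distributed.checkpoint.planner.SavePlan
    items: List[WriteItem]

def dedup(all_plans: List[SavePlan]) -> List[SavePlan]:
    seen = set()
    out = []
    for plan in all_plans:
        kept = [wi for wi in plan.items if wi.index not in seen]
        seen.update(wi.index for wi in kept)
        out.append(dataclasses.replace(plan, items=kept))
    return out

plans = [
    SavePlan([WriteItem("tensor_0"), WriteItem("r0")]),
    SavePlan([WriteItem("tensor_0"), WriteItem("r1")]),
]
print([len(p.items) for p in dedup(plans)])  # [2, 1], matching test_dedup_tensors above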

View File

@ -9,20 +9,18 @@ import pickle
import queue
import threading
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Callable, cast, Dict, List, Optional, Union
import fsspec
import torch
from fsspec import AbstractFileSystem
from fsspec.core import url_to_fs
import torch
from torch import Tensor
from torch._utils import _get_device_module
from torch.distributed._shard._utils import narrow_tensor_by_index
from torch.distributed.checkpoint.metadata import Metadata, MetadataIndex
from torch.distributed.checkpoint.planner import (
LoadItemType,
LoadPlan,
@ -145,10 +143,7 @@ class _OverlappingCpuLoader(_TensorLoader):
def _refill(self):
with self.device_module.stream(self.stream):
while (
not self._done
and self.in_flight_data < self.inflight_threshhold
):
while not self._done and self.in_flight_data < self.inflight_threshhold:
_, obj = self.items[self.idx]
self.idx += 1
tensor = self.resolve_fun(obj).detach()
@ -206,9 +201,7 @@ def _item_size(item: WriteItem) -> int:
return size * torch._utils._element_size(dtype)
def _split_by_size_and_type(
bins: int, items: List[WriteItem]
) -> List[List[WriteItem]]:
def _split_by_size_and_type(bins: int, items: List[WriteItem]) -> List[List[WriteItem]]:
if bins == 1:
return [items]
@ -276,16 +269,12 @@ def _write_files_from_queue(
planner.resolve_data,
)
tensor_w = [
wi for wi in write_items if wi.type != WriteItemType.BYTE_IO
]
tensor_w = [wi for wi in write_items if wi.type != WriteItemType.BYTE_IO]
for write_item in tensor_w:
loader.add(_item_size(write_item), write_item)
loader.start_loading()
bytes_w = [
wi for wi in write_items if wi.type == WriteItemType.BYTE_IO
]
bytes_w = [wi for wi in write_items if wi.type == WriteItemType.BYTE_IO]
write_results = []
with fs.transaction:
@ -351,9 +340,7 @@ class FsspecWriter(StorageWriter):
self.fs.makedirs(self.path, exist_ok=True)
return plan
def prepare_global_plan(
self, global_plan: List[SavePlan]
) -> List[SavePlan]:
def prepare_global_plan(self, global_plan: List[SavePlan]) -> List[SavePlan]:
new_plans = [
dataclasses.replace(plan, storage_data=_StoragePrefix(f"__{i}_"))
for i, plan in enumerate(global_plan)
@ -376,9 +363,7 @@ class FsspecWriter(StorageWriter):
file_queue: queue.Queue = queue.Queue()
if self.single_file_per_rank:
for bucket in _split_by_size_and_type(
self.thread_count, plan.items
):
for bucket in _split_by_size_and_type(self.thread_count, plan.items):
file_name = gen_file()
file_path = os.path.join(self.path, file_name)
file_queue.put((file_path, file_name, bucket))
@ -427,9 +412,7 @@ class FsspecWriter(StorageWriter):
fut.set_result(res)
return fut
def finish(
self, metadata: Metadata, results: List[List[WriteResult]]
) -> None:
def finish(self, metadata: Metadata, results: List[List[WriteResult]]) -> None:
storage_md = dict()
for wr_list in results:
storage_md.update({wr.index: wr.storage_data for wr in wr_list})
@ -495,16 +478,12 @@ class FsspecReader(StorageReader):
with fsspec.open(metadata_path, "rb") as metadata_file:
return pickle.load(metadata_file)
def set_up_storage_reader(
self, metadata: Metadata, is_coordinator: bool
) -> None:
def set_up_storage_reader(self, metadata: Metadata, is_coordinator: bool) -> None:
self.storage_data = metadata.storage_data
assert self.storage_data is not None
def prepare_local_plan(self, plan: LoadPlan) -> LoadPlan:
return plan
def prepare_global_plan(
self, global_plan: List[LoadPlan]
) -> List[LoadPlan]:
def prepare_global_plan(self, global_plan: List[LoadPlan]) -> List[LoadPlan]:
return global_plan
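The _split_by_size_and_type helper above buckets write items into `bins` files per rank. A standalone sketch of the greedy size-balancing idea (the real helper also routes BYTE_IO items separately; names here are illustrative): always drop the next-largest item into the currently lightest bin.

import heapq
from typing import List

def greedy_bins(sizes: List[int], bins: int) -> List[List[int]]:
    heap = [(0, b) for b in range(bins)]       # (total size, bin index)
    out: List[List[int]] = [[] for _ in range(bins)]
    for size in sorted(sizes, reverse=True):
        total, idx = heapq.heappop(heap)       # lightest bin so far
        out[idx].append(size)
        heapq.heappush(heap, (total + size, idx))
    return out

print(greedy_bins([8, 5, 4, 3, 1], 2))  # [[8, 3], [5, 4, 1]]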

View File

@ -1,16 +1,9 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
from typing import Dict, Tuple
from torch.distributed.checkpoint.metadata import (
STATE_DICT_TYPE,
)
from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE
from ._traverse import (
traverse_state_dict,
set_element,
OBJ_PATH,
STATE_DICT_ITEM,
)
from ._traverse import OBJ_PATH, set_element, STATE_DICT_ITEM, traverse_state_dict
"""
TODO:

View File

@ -3,29 +3,12 @@
import copy
import torch.distributed as dist
from torch.distributed._shard.sharded_tensor import Shard, ShardedTensor, ShardMetadata
from torch.distributed._shard.sharded_tensor.metadata import ShardedTensorMetadata
from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE
from torch.distributed.remote_device import _remote_device
from torch.distributed.checkpoint.metadata import (
STATE_DICT_TYPE,
)
from torch.distributed._shard.sharded_tensor import (
Shard,
ShardMetadata,
ShardedTensor,
)
from torch.distributed._shard.sharded_tensor.metadata import (
ShardedTensorMetadata,
)
from ._traverse import (
OBJ_PATH,
traverse_state_dict,
set_element,
STATE_DICT_ITEM,
)
from ._traverse import OBJ_PATH, set_element, STATE_DICT_ITEM, traverse_state_dict
from .utils import _element_wise_add, _normalize_device_info
@ -62,9 +45,7 @@ def _flatten_sharded_tensors(state_dict: STATE_DICT_TYPE) -> STATE_DICT_TYPE:
return
if len(inner_st.local_shards()) != 1:
raise ValueError(
"Cannot handle inner tensor with more than 1 shard"
)
raise ValueError("Cannot handle inner tensor with more than 1 shard")
inner_shard = inner_st.local_shards()[0]
local_shards = [

View File

@ -1,8 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
import torch
from typing import (
Callable,
cast,
Collection,
List,
Mapping,
@ -11,13 +10,12 @@ from typing import (
Tuple,
TypeVar,
Union,
cast,
)
from torch.distributed.checkpoint.metadata import (
STATE_DICT_TYPE,
)
import torch
from torch.distributed._shard.sharded_tensor.api import ShardedTensor
from torch.distributed._tensor import DTensor
from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE
PATH_ITEM = Union[str, int]
OBJ_PATH = Tuple[PATH_ITEM, ...]
@ -47,6 +45,7 @@ def traverse_state_dict(
By default, all collections with at least one ``torch.Tensor`` element are traversed.
Visitor takes a path argument that is a tuple of the keys used to reach it.
"""
# a value is terminal if it has no other container values inside it
def _is_terminal(value: STATE_DICT_ITEM) -> bool:
values: Collection[STATE_DICT_ITEM]
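A toy re-implementation of the path-tuple visitation that traverse_state_dict performs (the real helper decides terminality via _is_terminal above and treats tensors specially):

def visit(obj, path=(), out=None):
    out = {} if out is None else out
    if isinstance(obj, dict):
        for k, v in obj.items():
            visit(v, path + (k,), out)
    elif isinstance(obj, list):
        for i, v in enumerate(obj):
            visit(v, path + (i,), out)
    else:
        out[path] = obj  # terminal value, keyed by its OBJ_PATH-style tuple
    return out

print(visit({"key0": [{"key1": 1}]}))  # {('key0', 0, 'key1'): 1}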

View File

@ -1,5 +1,5 @@
from typing import Dict, Tuple, Any
import traceback as tb
from typing import Any, Dict, Tuple
WRAPPED_EXCEPTION = Tuple[BaseException, tb.StackSummary]
@ -15,9 +15,7 @@ def _is_wrapped_exception(obj: Any) -> bool:
return False
if len(obj) != 2:
return False
return isinstance(obj[0], BaseException) and isinstance(
obj[1], tb.StackSummary
)
return isinstance(obj[0], BaseException) and isinstance(obj[1], tb.StackSummary)
class CheckpointException(BaseException):
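For illustration, the producing side of the (exception, StackSummary) pairs that _is_wrapped_exception checks might look like this (a guess at the wrapping helper, not the actual implementation):

import traceback as tb

def wrap_exception(exc: BaseException):
    # Pair the exception with the traceback summary captured at raise time.
    return exc, tb.extract_tb(exc.__traceback__)

try:
    raise ValueError("boom")
except ValueError as e:
    pair = wrap_exception(e)
print(isinstance(pair[1], tb.StackSummary))  # True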

View File

@ -6,50 +6,42 @@ import logging
import operator
from collections import ChainMap
from functools import reduce
from typing import List, Tuple, Dict, Any, Union, cast
from typing import Any, cast, Dict, List, Tuple, Union
import torch
from torch.distributed._shard._utils import narrow_tensor_by_index
from torch.distributed._tensor import DTensor
from torch.distributed.checkpoint.planner import (
SavePlanner,
LoadPlanner,
SavePlan,
LoadPlan,
ReadItem,
WriteItem,
WriteItemType,
)
from torch.distributed.checkpoint.metadata import (
BytesStorageMetadata,
ChunkStorageMetadata,
TensorStorageMetadata,
MetadataIndex,
Metadata,
STATE_DICT_TYPE,
STORAGE_TYPES,
)
from torch.distributed.checkpoint.planner_helpers import (
_create_read_items,
_create_write_items,
_create_default_metadata_only_plan,
)
from torch.distributed.checkpoint._dedup_tensors import dedup_tensors
from torch.distributed.checkpoint._nested_dict import (
FLATTEN_MAPPING,
flatten_state_dict,
)
from torch.distributed.checkpoint._sharded_tensor_utils import (
_flatten_sharded_tensors,
)
from torch.distributed.checkpoint._dedup_tensors import dedup_tensors
from torch.distributed.checkpoint.utils import find_state_dict_object
from torch.distributed.checkpoint._sharded_tensor_utils import _flatten_sharded_tensors
from torch.distributed.checkpoint._traverse import set_element
from torch.distributed.checkpoint.metadata import (
BytesStorageMetadata,
ChunkStorageMetadata,
Metadata,
MetadataIndex,
STATE_DICT_TYPE,
STORAGE_TYPES,
TensorStorageMetadata,
)
from torch.distributed.checkpoint.planner import (
LoadPlan,
LoadPlanner,
ReadItem,
SavePlan,
SavePlanner,
WriteItem,
WriteItemType,
)
from torch.distributed.checkpoint.planner_helpers import (
_create_default_metadata_only_plan,
_create_read_items,
_create_write_items,
)
from torch.distributed.checkpoint.utils import find_state_dict_object
logger: logging.Logger = logging.getLogger(__name__)
@ -79,9 +71,7 @@ class DefaultSavePlanner(SavePlanner):
self.dedup_replicated_tensors = dedup_replicated_tensors
self.mappings = {}
def set_up_planner(
self, state_dict: STATE_DICT_TYPE, is_coordinator: bool
) -> None:
def set_up_planner(self, state_dict: STATE_DICT_TYPE, is_coordinator: bool) -> None:
if self.flatten_state_dict:
state_dict, self.mappings = flatten_state_dict(state_dict)
if self.flatten_sharded_tensors:
@ -90,9 +80,7 @@ class DefaultSavePlanner(SavePlanner):
self.is_coordinator = is_coordinator
def create_local_plan(self) -> SavePlan:
plan = create_default_local_save_plan(
self.state_dict, self.is_coordinator
)
plan = create_default_local_save_plan(self.state_dict, self.is_coordinator)
if self.flatten_state_dict:
plan = dataclasses.replace(plan, planner_data=self.mappings)
self.plan = plan
@ -114,9 +102,7 @@ class DefaultSavePlanner(SavePlanner):
# )
planner_data_dict = [p.planner_data for p in global_plan]
merged_mappings = dict(ChainMap(*planner_data_dict))
metadata = dataclasses.replace(
metadata, planner_data=merged_mappings
)
metadata = dataclasses.replace(metadata, planner_data=merged_mappings)
if not _validate_global_plan(global_plan, metadata):
raise ValueError("Failed to validate global plan")
@ -130,9 +116,7 @@ class DefaultSavePlanner(SavePlanner):
self.plan = new_plan
return new_plan
def resolve_data(
self, write_item: WriteItem
) -> Union[torch.Tensor, io.BytesIO]:
def resolve_data(self, write_item: WriteItem) -> Union[torch.Tensor, io.BytesIO]:
object = self.lookup_object(write_item.index)
return self.transform_object(write_item, object)
@ -222,9 +206,7 @@ class DefaultLoadPlanner(LoadPlanner):
def transform_tensor(self, read_item: ReadItem, tensor: torch.Tensor):
"""Extension from the planner interface to make it easy to extend the default planner."""
return narrow_tensor_by_index(
tensor, read_item.dest_offsets, read_item.lengths
)
return narrow_tensor_by_index(tensor, read_item.dest_offsets, read_item.lengths)
def create_default_local_load_plan(
@ -353,9 +335,7 @@ def _create_default_local_metadata(state_dict: STATE_DICT_TYPE) -> Metadata:
return md
def _check_box_overlap(
box0: ChunkStorageMetadata, box1: ChunkStorageMetadata
) -> bool:
def _check_box_overlap(box0: ChunkStorageMetadata, box1: ChunkStorageMetadata) -> bool:
"""Check if two boxes overlap. Tuples are (offset, lengths)."""
# For each dim of each shard, check if one shard resides on the other
# end of second shard with respect to that dim. As an example for a 2D
@ -385,9 +365,7 @@ def _check_box_bounds(
return True
def _validate_global_plan(
global_plan: List[SavePlan], metadata: Metadata
) -> bool:
def _validate_global_plan(global_plan: List[SavePlan], metadata: Metadata) -> bool:
all_good = True
for key, value in metadata.state_dict_metadata.items():
if isinstance(value, BytesStorageMetadata):
@ -402,7 +380,10 @@ def _validate_global_plan(
"""
key:%s has out of bounds chunk:
tensor-size:%s chunk: %s
""", key, value.size, chunk0
""",
key,
value.size,
chunk0,
)
all_good = False
chunks_volume += reduce(operator.mul, chunk0.sizes, 1)
@ -422,7 +403,10 @@ def _validate_global_plan(
"""
key:%s invalid fill tensor-volume:
%s chunks-volume: %s
""", key, tensor_volume, chunks_volume
""",
key,
tensor_volume,
chunks_volume,
)
all_good = False
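A worked instance of the volume test _validate_global_plan applies above: the chunk volumes must sum to the tensor volume (numbers illustrative).

import operator
from functools import reduce

tensor_size = [16]
chunks = [([0], [8]), ([8], [8])]  # (offsets, sizes) pairs, as in ChunkStorageMetadata
chunks_volume = sum(reduce(operator.mul, sizes, 1) for _, sizes in chunks)
tensor_volume = reduce(operator.mul, tensor_size, 1)
assert chunks_volume == tensor_volume  # the chunks exactly tile the tensor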

View File

@ -14,12 +14,10 @@ import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dist_cp
import torch.multiprocessing as mp
from torch.distributed.checkpoint.optimizer import load_sharded_optimizer_state_dict
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType
from torch.distributed.checkpoint.optimizer import (
load_sharded_optimizer_state_dict,
)
CHECKPOINT_DIR = f"/scratch/{os.environ['LOGNAME']}/checkpoint"

View File

@ -1,53 +1,38 @@
from abc import ABC, abstractmethod
import queue
import threading
import collections
from dataclasses import dataclass
import os
import dataclasses
import io
import os
import pickle
from typing import Optional, List, Union, Dict, cast
import queue
import threading
from abc import ABC, abstractmethod
from dataclasses import dataclass
from pathlib import Path
from typing import cast, Dict, List, Optional, Union
import torch
import torch.distributed as dist
from torch import Tensor
from torch._utils import _get_device_module
from torch.distributed._shard._utils import narrow_tensor_by_index
from torch.distributed.checkpoint.checkpointer import Checkpointer
from torch.futures import Future
from pathlib import Path
from .metadata import (
Metadata,
MetadataIndex,
)
from .storage import (
StorageReader,
StorageWriter,
WriteResult,
)
from .metadata import Metadata, MetadataIndex
from .planner import (
LoadItemType,
LoadPlanner,
LoadPlan,
LoadPlanner,
ReadItem,
SavePlan,
SavePlanner,
ReadItem,
WriteItem,
WriteItemType,
)
from .storage import StorageReader, StorageWriter, WriteResult
from .utils import _create_file_view
import torch.distributed as dist
from torch.distributed.checkpoint.checkpointer import Checkpointer
from torch.distributed._shard._utils import narrow_tensor_by_index
from torch._utils import _get_device_module
__all__ = [
"FileSystemWriter",
"FileSystemReader",
"FileSystemCheckpointer"
]
__all__ = ["FileSystemWriter", "FileSystemReader", "FileSystemCheckpointer"]
@dataclass
@ -150,10 +135,7 @@ class _OverlappingCpuLoader(_TensorLoader):
def _refill(self):
with self.device_module.stream(self.stream):
while (
not self._done
and self.in_flight_data < self.inflight_threshhold
):
while not self._done and self.in_flight_data < self.inflight_threshhold:
_, obj = self.items[self.idx]
self.idx += 1
tensor = self.resolve_fun(obj).detach()
@ -211,9 +193,7 @@ def _item_size(item: WriteItem) -> int:
return size * torch._utils._element_size(dtype)
def _split_by_size_and_type(
bins, items: List[WriteItem]
) -> List[List[WriteItem]]:
def _split_by_size_and_type(bins, items: List[WriteItem]) -> List[List[WriteItem]]:
if bins == 1:
return [items]
@ -276,16 +256,12 @@ def _write_files_from_queue(
planner.resolve_data,
)
tensor_w = [
wi for wi in write_items if wi.type != WriteItemType.BYTE_IO
]
tensor_w = [wi for wi in write_items if wi.type != WriteItemType.BYTE_IO]
for write_item in tensor_w:
loader.add(_item_size(write_item), write_item)
loader.start_loading()
bytes_w = [
wi for wi in write_items if wi.type == WriteItemType.BYTE_IO
]
bytes_w = [wi for wi in write_items if wi.type == WriteItemType.BYTE_IO]
write_results = []
with file_name.open("wb") as stream:
@ -358,9 +334,7 @@ class FileSystemWriter(StorageWriter):
self.path.mkdir(parents=True, exist_ok=True)
return plan
def prepare_global_plan(
self, global_plan: List[SavePlan]
) -> List[SavePlan]:
def prepare_global_plan(self, global_plan: List[SavePlan]) -> List[SavePlan]:
new_plans = [
dataclasses.replace(plan, storage_data=_StoragePrefix(f"__{i}_"))
for i, plan in enumerate(global_plan)
@ -383,9 +357,7 @@ class FileSystemWriter(StorageWriter):
file_queue: queue.Queue = queue.Queue()
if self.single_file_per_rank:
for bucket in _split_by_size_and_type(
self.thread_count, plan.items
):
for bucket in _split_by_size_and_type(self.thread_count, plan.items):
file_name = gen_file()
file_queue.put((self.path / file_name, file_name, bucket))
else:
@ -432,9 +404,7 @@ class FileSystemWriter(StorageWriter):
fut.set_result(res)
return fut
def finish(
self, metadata: Metadata, results: List[List[WriteResult]]
) -> None:
def finish(self, metadata: Metadata, results: List[List[WriteResult]]) -> None:
storage_md = dict()
for wr_list in results:
storage_md.update({wr.index: wr.storage_data for wr in wr_list})
@ -507,11 +477,10 @@ class FileSystemReader(StorageReader):
def prepare_local_plan(self, plan: LoadPlan) -> LoadPlan:
return plan
def prepare_global_plan(
self, global_plan: List[LoadPlan]
) -> List[LoadPlan]:
def prepare_global_plan(self, global_plan: List[LoadPlan]) -> List[LoadPlan]:
return global_plan
class FileSystemCheckpointer(Checkpointer):
"""An implementation of :py:class:`torch.distributed.checkpoint.checkpointer.Checkpointer`
for the file system. Wraps the creation and usage of ``FileSystemWriter`` and ``FileSystemReader``.
@ -547,11 +516,7 @@ class FileSystemCheckpointer(Checkpointer):
"""
storage_writer = FileSystemWriter(
path,
single_file_per_rank,
sync_files,
thread_count,
per_thread_copy_ahead
path, single_file_per_rank, sync_files, thread_count, per_thread_copy_ahead
)
storage_reader = FileSystemReader(path)
@ -562,5 +527,5 @@ class FileSystemCheckpointer(Checkpointer):
coordinator_rank=coordinator_rank,
no_dist=no_dist,
load_planner=load_planner,
save_planner=save_planner
save_planner=save_planner,
)
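A toy version of the back-pressure loop in _OverlappingCpuLoader._refill earlier in this file: keep staging items until the in-flight total reaches the threshold (sizes illustrative).

inflight_threshold = 1_000_000
pending = [600_000, 600_000, 600_000]
in_flight, staged = 0, []
while pending and in_flight < inflight_threshold:
    item = pending.pop(0)
    staged.append(item)
    in_flight += item
print(len(staged), in_flight)  # 2 1200000: the third item waits for a drain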

View File

@ -1,12 +1,10 @@
from dataclasses import dataclass, field
from typing import Dict, List, Union, Optional, Sequence, Any
from torch.distributed._shard.sharded_tensor.metadata import TensorProperties
from torch.distributed.checkpoint.stateful import StatefulT
from typing import Any, Dict, List, Optional, Sequence, Union
import torch
from torch.distributed._shard.sharded_tensor import (
ShardedTensor,
)
from torch.distributed._shard.sharded_tensor import ShardedTensor
from torch.distributed._shard.sharded_tensor.metadata import TensorProperties
from torch.distributed.checkpoint.stateful import StatefulT
__all__ = [
"ChunkStorageMetadata",

View File

@ -1,49 +1,41 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
import dataclasses
from typing import Dict, List, Optional, Sequence, Tuple, Union, cast
from torch.distributed.checkpoint.planner import LoadPlan
from typing import cast, Dict, List, Optional, Sequence, Tuple, Union
import torch
import torch.distributed as dist
from torch._utils import _get_device_module
from torch.distributed._shard.sharded_tensor.api import ShardedTensor
from torch.distributed._shard.sharded_tensor.metadata import TensorProperties
from torch.distributed._shard.sharded_tensor.shard import Shard
from torch.distributed._shard.sharding_spec.chunk_sharding_spec import (
ChunkShardingSpec,
)
import torch.distributed.checkpoint as dist_cp
from torch.distributed._shard.sharding_spec.chunk_sharding_spec import ChunkShardingSpec
from torch.distributed._tensor import DTensor
from torch.distributed.checkpoint._nested_dict import unflatten_state_dict
from torch.distributed.checkpoint.default_planner import DefaultLoadPlanner
from torch.distributed.checkpoint.metadata import (
BytesStorageMetadata,
ChunkStorageMetadata,
Metadata,
MetadataIndex,
STATE_DICT_TYPE,
TensorStorageMetadata,
ChunkStorageMetadata,
)
from torch.distributed.distributed_c10d import _get_default_group
from torch.distributed.fsdp._shard_utils import _create_chunk_sharded_tensor
from torch.distributed.checkpoint.planner import LoadPlan, LoadPlanner
from torch.distributed.checkpoint.planner_helpers import (
create_read_items_for_chunk_list,
_create_read_items,
create_read_items_for_chunk_list,
)
from torch.distributed.remote_device import _remote_device
from torch.distributed._tensor import DTensor
from torch.distributed.checkpoint.default_planner import (
DefaultLoadPlanner,
)
from torch.distributed.checkpoint.planner import LoadPlanner
from torch.distributed.checkpoint._nested_dict import unflatten_state_dict
from torch.distributed.checkpoint.state_dict_loader import load_state_dict
from torch.distributed.checkpoint.storage import StorageReader
from torch.distributed.checkpoint.utils import (
_element_wise_add,
_element_wise_sub,
_normalize_device_info
_normalize_device_info,
)
from torch._utils import _get_device_module
from torch.distributed.distributed_c10d import _get_default_group
from torch.distributed.fsdp._shard_utils import _create_chunk_sharded_tensor
from torch.distributed.remote_device import _remote_device
STATE_DICT_2D_LAYOUT = Dict[str, Tuple[Optional[Sequence[int]], Sequence[int]]]
@ -59,7 +51,9 @@ def _gen_rank_device(global_rank: int, device_type: str = "cuda") -> str:
return "cpu"
device_module = _get_device_module(device_type)
if device_module.is_available():
return _normalize_device_info(device_type, global_rank % device_module.device_count())
return _normalize_device_info(
device_type, global_rank % device_module.device_count()
)
return "cpu"
@ -90,18 +84,17 @@ def _is_nested_tensor(val: torch.Tensor) -> bool:
if type(val.local_shards()[0].tensor) is ShardedTensor:
return True
if type(val.local_shards()[0].tensor) is DTensor:
raise ValueError(
"Cannot handle DTensor nested insided ShardedTensor"
)
raise ValueError("Cannot handle DTensor nested insided ShardedTensor")
elif type(val) is DTensor and (
type(val._local_tensor) is DTensor
or type(val._local_tensor) is ShardedTensor
type(val._local_tensor) is DTensor or type(val._local_tensor) is ShardedTensor
):
raise ValueError("Cannot handle nested DTensor")
return False
def _alloc_tensor(props: TensorProperties, size: Sequence[int], device_type: str = "cuda") -> torch.Tensor:
def _alloc_tensor(
props: TensorProperties, size: Sequence[int], device_type: str = "cuda"
) -> torch.Tensor:
return torch.empty(
size=size,
dtype=props.dtype,
@ -181,9 +174,7 @@ class _ReaderWithOffset(DefaultLoadPlanner):
local_chunks = [
ChunkStorageMetadata(
offsets=torch.Size(
_element_wise_add(
original_shard.metadata.shard_offsets, offset
)
_element_wise_add(original_shard.metadata.shard_offsets, offset)
),
sizes=torch.Size(original_shard.metadata.shard_sizes),
)
@ -196,9 +187,7 @@ class _ReaderWithOffset(DefaultLoadPlanner):
# TODO: we should change _create_sharded_read_items to have more ergonomic API
for ri in reqs:
assert ri.dest_index.offset is not None
original_offset = _element_wise_sub(
ri.dest_index.offset, offset
)
original_offset = _element_wise_sub(ri.dest_index.offset, offset)
original_index = dataclasses.replace(
ri.dest_index, offset=torch.Size(original_offset)
)
@ -214,7 +203,7 @@ class _ReaderWithOffset(DefaultLoadPlanner):
def load_sharded_optimizer_state_dict(
model_state_dict: STATE_DICT_TYPE,
optimizer_key: str,
storage_reader: dist_cp.StorageReader,
storage_reader: StorageReader,
planner: Optional[LoadPlanner] = None,
) -> STATE_DICT_TYPE:
"""
@ -273,7 +262,9 @@ def load_sharded_optimizer_state_dict(
if dp_pg is None:
placements = []
for i in range(dist.get_world_size()):
device_info = _normalize_device_info(dp_pg_device_type, i % device_module.device_count())
device_info = _normalize_device_info(
dp_pg_device_type, i % device_module.device_count()
)
placements.append(f"rank:{i}/{device_info}")
sharding_spec = ChunkShardingSpec(dim=0, placements=placements) # type: ignore[arg-type]
else:
@ -294,7 +285,9 @@ def load_sharded_optimizer_state_dict(
# value: TensorStorageMetadata
if value.size.numel() == 1:
state_dict[key] = _alloc_tensor(value.properties, value.size, dp_pg_device_type)
state_dict[key] = _alloc_tensor(
value.properties, value.size, dp_pg_device_type
)
elif dp_pg is None:
state_dict[key] = _create_chunk_sharded_tensor(
_alloc_tensor(value.properties, value.size, dp_pg_device_type),
@ -313,10 +306,7 @@ def load_sharded_optimizer_state_dict(
local_shards = []
current_rank = dist.get_rank(dp_pg)
for shard_md in st_md.shards_metadata:
if (
cast(_remote_device, shard_md.placement).rank()
!= current_rank
):
if cast(_remote_device, shard_md.placement).rank() != current_rank:
continue
local_shards.append(
Shard(
@ -331,18 +321,13 @@ def load_sharded_optimizer_state_dict(
local_shards, st_md, process_group=dp_pg
)
if (
spec_key in layout_specs
and layout_specs[spec_key][0] is not None
):
fqn_to_offset[key] = cast(
Sequence[int], layout_specs[spec_key][0]
)
if spec_key in layout_specs and layout_specs[spec_key][0] is not None:
fqn_to_offset[key] = cast(Sequence[int], layout_specs[spec_key][0])
state_dict[key] = st
# Whether we unflatten before or after doesn't matter
dist_cp.load_state_dict(
load_state_dict(
state_dict=state_dict,
storage_reader=storage_reader,
# FIXME the type of planner is wrong in load_state_dict
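# Hedged usage sketch for load_sharded_optimizer_state_dict: the checkpoint
# path, model, and optimizer below are illustrative assumptions, and an
# initialized process group plus an existing checkpoint are required.
import torch
import torch.distributed.checkpoint as dist_cp
from torch.distributed.checkpoint.optimizer import load_sharded_optimizer_state_dict

model = torch.nn.Linear(4, 4)  # stand-in for a sharded model
optim = torch.optim.SGD(model.parameters(), lr=0.1)

model_state_dict = model.state_dict()
dist_cp.load_state_dict(  # deprecated alias; `dist_cp.load` is the successor
    state_dict=model_state_dict,
    storage_reader=dist_cp.FileSystemReader("/tmp/ckpt"),
)
optim_state = load_sharded_optimizer_state_dict(
    model_state_dict=model_state_dict,
    optimizer_key="optimizer",
    storage_reader=dist_cp.FileSystemReader("/tmp/ckpt"),
)
optim.load_state_dict(optim_state["optimizer"])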

View File

@ -1,19 +1,13 @@
import abc
from dataclasses import dataclass
import io
from typing import List, Tuple, Any, Union, Optional
from dataclasses import dataclass
from enum import auto, Enum
from typing import Any, List, Optional, Tuple, Union
from enum import Enum, auto
import torch
from torch.distributed._shard.sharded_tensor.metadata import TensorProperties
from .metadata import (
ChunkStorageMetadata,
MetadataIndex,
Metadata,
STATE_DICT_TYPE,
)
from .metadata import ChunkStorageMetadata, Metadata, MetadataIndex, STATE_DICT_TYPE
__all__ = [
@ -223,9 +217,7 @@ class SavePlanner(abc.ABC):
pass
@abc.abstractmethod
def resolve_data(
self, write_item: WriteItem
) -> Union[torch.Tensor, io.BytesIO]:
def resolve_data(self, write_item: WriteItem) -> Union[torch.Tensor, io.BytesIO]:
"""
Transform and prepare ``write_item`` from ``state_dict`` for storage, ensuring idempotency and thread-safety.
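# A minimal sketch of a resolve_data implementation (a toy, not
# DefaultSavePlanner): look the object up by its fully-qualified name and
# hand storage either a tensor or a BytesIO.
import io
import torch

class _ToySavePlanner:
    def __init__(self, state_dict):
        self.state_dict = state_dict

    def resolve_data(self, write_item):
        obj = self.state_dict[write_item.index.fqn]
        if isinstance(obj, torch.Tensor):
            return obj.detach()  # tensors are handed to storage as-is
        buf = io.BytesIO()
        torch.save(obj, buf)  # non-tensor objects are serialized to bytes
        return buf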

View File

@ -1,7 +1,6 @@
from typing import Any, List
import torch
from torch.distributed._shard.metadata import ShardMetadata
from torch.distributed._shard.sharded_tensor import ShardedTensor
from torch.distributed._shard.sharded_tensor.metadata import TensorProperties
@ -16,7 +15,6 @@ from .metadata import (
STORAGE_TYPES,
TensorStorageMetadata,
)
from .planner import (
LoadItemType,
ReadItem,
@ -25,7 +23,6 @@ from .planner import (
WriteItem,
WriteItemType,
)
from .resharding import (
_check_shard_metadata_pair_overlap,
_shards_get_overlap_region_wrt_saved_tensor,

View File

@ -1,12 +1,13 @@
from typing import List, Tuple
from torch.distributed.checkpoint.metadata import (
ChunkStorageMetadata
)
from torch.distributed.checkpoint.metadata import ChunkStorageMetadata
__all__: List[str] = []
def _check_shard_metadata_pair_overlap(shard1: ChunkStorageMetadata, shard2: ChunkStorageMetadata):
def _check_shard_metadata_pair_overlap(
shard1: ChunkStorageMetadata, shard2: ChunkStorageMetadata
):
"""Check if two shards overlap."""
# For each dim of each shard, check if one shard resides on the other
# end of the second shard with respect to that dim. As an example for a 2D
@ -21,6 +22,7 @@ def _check_shard_metadata_pair_overlap(shard1: ChunkStorageMetadata, shard2: Chu
return True
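# Illustration of the rule above with real ChunkStorageMetadata values
# (the chunks themselves are made up): two chunks overlap unless, in some
# dimension, one ends at or before the other begins.
import torch
from torch.distributed.checkpoint.metadata import ChunkStorageMetadata

a = ChunkStorageMetadata(offsets=torch.Size([0, 0]), sizes=torch.Size([4, 4]))
b = ChunkStorageMetadata(offsets=torch.Size([2, 2]), sizes=torch.Size([4, 4]))
# a covers [0, 4) x [0, 4) and b covers [2, 6) x [2, 6), so they overlap
# in both dimensions and _check_shard_metadata_pair_overlap returns True.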
def _shards_get_overlap_region_wrt_saved_tensor(
saved_shard: ChunkStorageMetadata, current_shard: ChunkStorageMetadata
) -> List[Tuple[int, int, int, int]]:
@ -56,9 +58,7 @@ def _shards_get_overlap_region_wrt_saved_tensor(
if saved_shard_offset > current_shard_offset:
offset_for_saved_tensor = 0
offset_for_current_tensor = (
saved_shard_offset - current_shard_offset
)
offset_for_current_tensor = saved_shard_offset - current_shard_offset
else:
offset_for_saved_tensor = current_shard_offset - saved_shard_offset
offset_for_current_tensor = 0
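# Worked one-dimensional example of the offset bookkeeping above (numbers
# are illustrative): the saved shard covers [8, 16) and the current shard
# covers [4, 12), so the overlap [8, 12) starts at offset 0 of the saved
# tensor and offset 4 of the current tensor.
saved_shard_offset, current_shard_offset = 8, 4
if saved_shard_offset > current_shard_offset:
    offset_for_saved_tensor = 0
    offset_for_current_tensor = saved_shard_offset - current_shard_offset
else:
    offset_for_saved_tensor = current_shard_offset - saved_shard_offset
    offset_for_current_tensor = 0
assert (offset_for_saved_tensor, offset_for_current_tensor) == (0, 4)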

View File

@ -1,20 +1,18 @@
from typing import Any, Dict, Optional
import warnings
from typing import Any, Dict, Optional
import torch
import torch.distributed as dist
from torch.distributed.checkpoint.stateful import Stateful
from .storage import (
StorageReader,
)
from .planner import LoadPlanner
from .default_planner import DefaultLoadPlanner
from .utils import _DistWrapper, _all_gather_keys
from .planner import LoadPlanner
from .storage import StorageReader
from .utils import _all_gather_keys, _DistWrapper
__all__ = ["load_state_dict", "load"]
def load_state_dict(
state_dict: Dict[str, Any],
storage_reader: StorageReader,
@ -28,7 +26,10 @@ def load_state_dict(
"'load_state_dict' is deprecated and will be removed in future versions. Please use 'load' instead."
)
# TODO: test returning `load` here instead.
return _load_state_dict(state_dict, storage_reader, process_group, coordinator_rank, no_dist, planner)
return _load_state_dict(
state_dict, storage_reader, process_group, coordinator_rank, no_dist, planner
)
def load(
state_dict: Dict[str, Any],
@ -124,7 +125,9 @@ def load(
elem = state_dict[key]
statetful_sd[key] = elem.state_dict() if isinstance(elem, Stateful) else elem
_load_state_dict(statetful_sd, storage_reader, process_group, coordinator_rank, no_dist, planner)
_load_state_dict(
statetful_sd, storage_reader, process_group, coordinator_rank, no_dist, planner
)
for key in keys:
if key not in state_dict:
continue
@ -133,6 +136,7 @@ def load(
elem.load_state_dict(statetful_sd[key])
state_dict[key] = elem
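# A minimal sketch of the non-deprecated `load` entry point; the path and
# model are illustrative, and an initialized process group plus an existing
# checkpoint are assumed.
import torch
import torch.distributed.checkpoint as dist_cp

model = torch.nn.Linear(4, 4)  # stand-in model
state_dict = {"model": model.state_dict()}
dist_cp.load(
    state_dict=state_dict,
    storage_reader=dist_cp.FileSystemReader("/tmp/ckpt"),
)
model.load_state_dict(state_dict["model"])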
def _load_state_dict(
state_dict: Dict[str, Any],
storage_reader: StorageReader,
@ -141,7 +145,6 @@ def _load_state_dict(
no_dist: bool = False,
planner: Optional[LoadPlanner] = None,
) -> None:
torch._C._log_api_usage_once("torch.distributed.checkpoint.load_state_dict")
distW = _DistWrapper(process_group, not no_dist, coordinator_rank)

View File

@ -1,18 +1,14 @@
from typing import Optional
import warnings
from typing import Optional
import torch
import torch.distributed as dist
from torch.distributed.checkpoint.stateful import Stateful
from .planner import SavePlanner
from .default_planner import DefaultSavePlanner
from .storage import (
StorageWriter,
)
from .metadata import Metadata, STATE_DICT_TYPE
from .planner import SavePlanner
from .storage import StorageWriter
from .utils import _DistWrapper
__all__ = ["save_state_dict", "save"]
@ -32,7 +28,10 @@ def save_state_dict(
)
# TODO: test returning `save` here instead.
return _save_state_dict(state_dict, storage_writer, process_group, coordinator_rank, no_dist, planner)
return _save_state_dict(
state_dict, storage_writer, process_group, coordinator_rank, no_dist, planner
)
def save(
state_dict: STATE_DICT_TYPE,
@ -108,7 +107,9 @@ def save(
dumpable_state_dict = {}
for key, elem in state_dict.items():
dumpable_state_dict[key] = elem.state_dict() if isinstance(elem, Stateful) else elem
dumpable_state_dict[key] = (
elem.state_dict() if isinstance(elem, Stateful) else elem
)
return _save_state_dict(
dumpable_state_dict,
@ -116,9 +117,10 @@ def save(
process_group,
coordinator_rank,
no_dist,
planner
planner,
)
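# Matching sketch for the save path (path and model are illustrative; an
# initialized process group is assumed, mirroring the load sketch above):
import torch
import torch.distributed.checkpoint as dist_cp

model = torch.nn.Linear(4, 4)  # stand-in model
dist_cp.save(
    state_dict={"model": model.state_dict()},
    storage_writer=dist_cp.FileSystemWriter("/tmp/ckpt"),
)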
def _save_state_dict(
state_dict: STATE_DICT_TYPE,
storage_writer: StorageWriter,
@ -127,7 +129,6 @@ def _save_state_dict(
no_dist: bool = False,
planner: Optional[SavePlanner] = None,
) -> Metadata:
torch._C._log_api_usage_once("torch.distributed.checkpoint.save_state_dict")
distW = _DistWrapper(process_group, not no_dist, coordinator_rank)
@ -149,9 +150,7 @@ def _save_state_dict(
nonlocal global_metatadata
assert planner is not None
all_local_plans, global_metatadata = planner.create_global_plan(
all_local_plans
)
all_local_plans, global_metatadata = planner.create_global_plan(all_local_plans)
all_local_plans = storage_writer.prepare_global_plan(all_local_plans)
return all_local_plans

View File

@ -1,20 +1,11 @@
import abc
from dataclasses import dataclass
from typing import List, Any
from typing import Any, List
from torch.futures import Future
from .metadata import (
Metadata,
MetadataIndex,
)
from .planner import (
LoadPlan,
SavePlan,
SavePlanner,
LoadPlanner,
)
from .metadata import Metadata, MetadataIndex
from .planner import LoadPlan, LoadPlanner, SavePlan, SavePlanner
__all__ = ["WriteResult", "StorageWriter", "StorageReader"]
@ -115,9 +106,7 @@ class StorageWriter(abc.ABC):
pass
@abc.abstractmethod
def finish(
self, metadata: Metadata, results: List[List[WriteResult]]
) -> None:
def finish(self, metadata: Metadata, results: List[List[WriteResult]]) -> None:
"""
Write the metadata and mark the current checkpoint as successful.
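# Hypothetical minimal finish() for a custom StorageWriter subclass: persist
# the metadata blob so readers can later discover the checkpoint layout.
# A sketch only; the built-in FileSystemWriter is the real reference.
import pickle
from typing import List

class _ToyWriter:
    def __init__(self, path: str) -> None:
        self.path = path

    def finish(self, metadata: Metadata, results: List[List[WriteResult]]) -> None:
        with open(f"{self.path}/.metadata", "wb") as f:
            pickle.dump(metadata, f)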

View File

@ -1,37 +1,21 @@
import os
import io
import itertools
from typing import (
List,
Callable,
Optional,
Union,
TypeVar,
Dict,
Any,
cast,
Sequence
)
import torch.distributed as dist
from .api import (
CheckpointException,
_wrap_exception,
_is_wrapped_exception,
WRAPPED_EXCEPTION,
)
import os
from typing import Any, Callable, cast, Dict, List, Optional, Sequence, TypeVar, Union
import torch
from torch.distributed._shard.sharded_tensor import (
ShardedTensor,
)
import torch.distributed as dist
from torch.distributed._shard.sharded_tensor import ShardedTensor
from torch.distributed._shard.sharded_tensor.shard import Shard
from torch.distributed._tensor import DTensor
from .metadata import (
STATE_DICT_TYPE,
MetadataIndex,
from .api import (
_is_wrapped_exception,
_wrap_exception,
CheckpointException,
WRAPPED_EXCEPTION,
)
from .metadata import MetadataIndex, STATE_DICT_TYPE
__all__ = ["find_tensor_shard", "find_state_dict_object"]
@ -47,6 +31,7 @@ def _get_failure_dict(
{i: err for i, err in enumerate(results) if _is_wrapped_exception(err)},
)
def _all_gather_keys(local_dict: Dict[Any, Any]) -> List[Any]:
"""Gathers all keys, and returns them sorted."""
keys = list(local_dict.keys())
@ -55,6 +40,7 @@ def _all_gather_keys(local_dict: Dict[Any, Any]) -> List[Any]:
dist.all_gather_object(gathered_keys, keys)
return sorted(set(itertools.chain.from_iterable(gathered_keys)))
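# Local illustration of the gather-and-sort step above, using the key lists
# two hypothetical ranks might contribute (no process group needed):
import itertools

gathered = [["b", "a"], ["a", "c"]]
assert sorted(set(itertools.chain.from_iterable(gathered))) == ["a", "b", "c"]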
class _DistWrapper:
"""
This is a wrapper around a process group (PG) that provides a series of features around object collectives.
@ -123,9 +109,7 @@ class _DistWrapper:
def all_gather_object(self, object: T) -> List[T]:
"""Implement functionality similar to c10d::all_gather_object but without distributed enabled."""
if self.use_dist:
gather_objs = cast(
List[T], [None] * dist.get_world_size(self.group)
)
gather_objs = cast(List[T], [None] * dist.get_world_size(self.group))
dist.all_gather_object(
object_list=gather_objs, obj=object, group=self.group
@ -140,9 +124,7 @@ class _DistWrapper:
gather_result = cast(List[T], [None])
dist.scatter_object_list(
scatter_object_output_list=gather_result,
scatter_object_input_list=object_list
if self.is_coordinator
else None,
scatter_object_input_list=object_list if self.is_coordinator else None,
src=self.coordinator_rank,
group=self.group,
)
@ -282,9 +264,7 @@ class _DistWrapper:
try:
result = map_fun()
except BaseException as e:
result = CheckpointException(
step, {self.rank: _wrap_exception(e)}
)
result = CheckpointException(step, {self.rank: _wrap_exception(e)})
final_result = self.broadcast_object(result)
if isinstance(final_result, CheckpointException):
raise final_result
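# Minimal sketch of the all-or-nothing step pattern above, with hypothetical
# stand-ins for the broadcast and the exception type:
class _StepError(Exception):
    pass

def _run_step(map_fun, broadcast):
    try:
        result = map_fun()
    except BaseException as e:  # mirror the broad catch above
        result = _StepError(repr(e))
    final = broadcast(result)  # every rank observes the same outcome
    if isinstance(final, _StepError):
        raise final
    return final

# Single-process usage, where broadcast is the identity:
assert _run_step(lambda: 42, lambda x: x) == 42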
@ -302,22 +282,17 @@ def _find_shard(tensor: ShardedTensor, index: MetadataIndex) -> Shard:
if index.index is not None:
if (
len(shards) > index.index
and torch.Size(shards[index.index].metadata.shard_offsets)
== index.offset
and torch.Size(shards[index.index].metadata.shard_offsets) == index.offset
):
return shards[index.index]
for shard in shards:
if torch.Size(shard.metadata.shard_offsets) == index.offset:
return shard
raise ValueError(
f"Could not find shard at '{index.offset}' for FQN: '{index.fqn}'"
)
raise ValueError(f"Could not find shard at '{index.offset}' for FQN: '{index.fqn}'")
def find_tensor_shard(
tensor: torch.Tensor, index: MetadataIndex
) -> torch.Tensor:
def find_tensor_shard(tensor: torch.Tensor, index: MetadataIndex) -> torch.Tensor:
if isinstance(tensor, DTensor):
return tensor.to_local()
if isinstance(tensor, ShardedTensor):
@ -332,9 +307,7 @@ def find_tensor_shard(
return tensor
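# Sketch: for a plain (non-sharded) tensor, find_tensor_shard is assumed to
# be the identity; the DTensor/ShardedTensor branches need an initialized
# process group and are not exercised here.
import torch
from torch.distributed.checkpoint.metadata import MetadataIndex
from torch.distributed.checkpoint.utils import find_tensor_shard

t = torch.arange(4)
assert find_tensor_shard(t, MetadataIndex(fqn="t")) is t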
def find_state_dict_object(
state_dict: STATE_DICT_TYPE, index: MetadataIndex
) -> Any:
def find_state_dict_object(state_dict: STATE_DICT_TYPE, index: MetadataIndex) -> Any:
if index.fqn not in state_dict:
raise ValueError(f"Could not find FQN: '{index.fqn}'")
obj = state_dict[index.fqn]