Summary: It turns out that the device passed as a parameter to is_pinned is meant to be the accelerator device with respect to which pinning is expected. Passing 'cpu' always makes the return value False, regardless of whether the tensor actually is a CPU tensor pinned for CUDA. Moreover, https://github.com/pytorch/pytorch/pull/126376 is about to be merged; it automatically uses the correct accelerator device, which obviates the need for users to pass any explicit device and avoids creating a CUDA context for pure CPU tensors.

Note: the test https://www.internalfb.com/intern/test/844425019931542?ref_report_id=0 is expected to be broken by this diff, but it should be fixed forward by https://github.com/pytorch/pytorch/pull/126376.

Test Plan: Sandcastle.

Differential Revision: D59283190

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129972
Approved by: https://github.com/LucasLLC
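A minimal sketch of the behavior described above (assuming a CUDA-enabled build; the tensor name t is illustrative and the expected results are shown as comments):

    import torch

    t = torch.empty(8).pin_memory()  # page-locked host tensor, pinned for CUDA
    t.is_pinned()                    # True: checked against the accelerator device
    t.is_pinned("cpu")               # False: 'cpu' is never the pinning accelerator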
# Owner(s): ["oncall: distributed"]

import sys

import torch

from torch.distributed._shard.sharded_tensor import (
    Shard,
    ShardedTensor,
    ShardedTensorMetadata,
    ShardMetadata,
)
from torch.distributed._shard.sharded_tensor.metadata import TensorProperties
from torch.distributed.checkpoint.metadata import MetadataIndex
from torch.distributed.checkpoint.utils import find_state_dict_object

from torch.testing._internal.common_utils import (
    run_tests,
    TEST_WITH_DEV_DBG_ASAN,
    TestCase,
)
from torch.testing._internal.distributed.distributed_utils import with_fake_comms

if TEST_WITH_DEV_DBG_ASAN:
    print(
        "Skip dev-asan as torch + multiprocessing spawn have known issues",
        file=sys.stderr,
    )
    sys.exit(0)

def create_sharded_tensor(rank, world_size, shards_per_rank):
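    """
    Build a 1-D ShardedTensor with world_size * shards_per_rank shards.

    Shard idx holds 8 elements at offsets [idx * 8, idx * 8 + 8) and is
    placed on rank idx // shards_per_rank; only the shards owned by this
    rank are materialized as local tensors.
    """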
    shards_metadata = []
    local_shards = []
    for idx in range(0, world_size * shards_per_rank):
        shard_rank = idx // shards_per_rank
        shard_md = ShardMetadata(
            shard_offsets=[idx * 8], shard_sizes=[8], placement=f"rank:{shard_rank}/cpu"
        )
        shards_metadata.append(shard_md)
        if shard_rank == rank:
            shard = Shard.from_tensor_and_offsets(
                torch.rand(*shard_md.shard_sizes),
                shard_offsets=shard_md.shard_offsets,
                rank=rank,
            )
            local_shards.append(shard)

    sharded_tensor_md = ShardedTensorMetadata(
        shards_metadata=shards_metadata,
        size=torch.Size([8 * len(shards_metadata)]),
        tensor_properties=TensorProperties.create_from_tensor(torch.zeros(1)),
    )

    return ShardedTensor._init_from_local_shards_and_global_metadata(
        local_shards=local_shards, sharded_tensor_metadata=sharded_tensor_md
    )

class TestMetadataIndex(TestCase):
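    """Tests for MetadataIndex semantics and find_state_dict_object lookup."""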
    def test_init_convert_offset(self):
        # A list offset and a torch.Size offset compare equal.
        a = MetadataIndex("foo", [1, 2])
        b = MetadataIndex("foo", torch.Size([1, 2]))
        self.assertEqual(a, b)

    def test_index_hint_ignored_on_equals(self):
        a = MetadataIndex("foo")
        b = MetadataIndex("foo", index=99)
        self.assertEqual(a, b)

    def test_index_hint_ignored_on_hash(self):
        a = MetadataIndex("foo")
        b = MetadataIndex("foo", index=99)
        self.assertEqual(hash(a), hash(b))

    def test_flat_data(self):
        state_dict = {
            "a": torch.rand(10),
            "b": [1, 2, 3],
        }

        # A plain tensor is returned whole; lookup by its origin offset ([0])
        # also works, and index hints are ignored.
        a = find_state_dict_object(state_dict, MetadataIndex("a"))
        self.assertEqual(a, state_dict["a"])
        a = find_state_dict_object(state_dict, MetadataIndex("a", [0]))
        self.assertEqual(a, state_dict["a"])
        a = find_state_dict_object(state_dict, MetadataIndex("a", index=99))
        self.assertEqual(a, state_dict["a"])

        b = find_state_dict_object(state_dict, MetadataIndex("b"))
        self.assertEqual(b, state_dict["b"])
        b = find_state_dict_object(state_dict, MetadataIndex("b", index=1))
        self.assertEqual(b, state_dict["b"])

        # Missing FQNs and offsets into non-sharded values both raise.
        with self.assertRaisesRegex(ValueError, "FQN"):
            find_state_dict_object(state_dict, MetadataIndex("c"))
        with self.assertRaisesRegex(ValueError, "ShardedTensor"):
            find_state_dict_object(state_dict, MetadataIndex("b", [1]))

    @with_fake_comms(rank=0, world_size=2)
    def test_sharded_tensor_lookup(self):
        st = create_sharded_tensor(rank=0, world_size=2, shards_per_rank=3)
        state_dict = {"st": st}
        obj = find_state_dict_object(state_dict, MetadataIndex("st", [8]))
        self.assertEqual(obj, st.local_shards()[1].tensor)

        # good hint: the index fast path returns the shard directly
        obj = find_state_dict_object(state_dict, MetadataIndex("st", [8], index=1))
        self.assertEqual(obj, st.local_shards()[1].tensor)

        # bad hint: a wrong index falls back to scanning the local shards
        obj = find_state_dict_object(state_dict, MetadataIndex("st", [8], index=2))
        self.assertEqual(obj, st.local_shards()[1].tensor)

        # broken hint: an out-of-range index is likewise ignored
        obj = find_state_dict_object(state_dict, MetadataIndex("st", [8], index=99))
        self.assertEqual(obj, st.local_shards()[1].tensor)

        with self.assertRaisesRegex(ValueError, "no offset was provided"):
            find_state_dict_object(state_dict, MetadataIndex("st"))

        with self.assertRaisesRegex(ValueError, "Could not find shard"):
            find_state_dict_object(state_dict, MetadataIndex("st", [1]))


if __name__ == "__main__":
    run_tests()