mirror of https://github.com/zebrajr/pytorch.git (synced 2025-12-07 12:21:27 +01:00)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/75844

Add an extensible Sharder and ShardingPlanner so that advanced users can extend them and implement custom sharding. Details:
* ShardingPlanner is an interface that can be extended to plan advanced sharding strategies. It produces a ShardingPlan that outlines the detailed sharding strategy for each submodule; the plan can also include custom sharders.
* Sharder is another interface that can be extended to define advanced sharding strategies for specific modules (e.g. a collection of `nn.EmbeddingBag`s). A user-defined custom sharder is picked up by the `shard_module` API: `shard_module` calls `Sharder.shard` on the submodule, and the output of `Sharder.shard` replaces the original submodule.

ghstack-source-id: 154978919

Test Plan: test_sharder test_sharding_plan

Reviewed By: pritamdamania87

Differential Revision: D35658657

fbshipit-source-id: ad9deb39a6a0ab91ceb1de3a78c4709aa50ad699
(cherry picked from commit f9796471bd7f8ad94a9d3462cac30eacacf68b31)
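To make the flow above concrete, here is a minimal sketch of plugging a custom Sharder into a ShardingPlan, using only the interfaces the test file below exercises (Sharder, ShardingPlan, shard_module). MySharder, apply_custom_sharding, and the "ebc" path are illustrative names rather than part of the API, and shard_module is assumed to run inside an initialized process group, as the test harness below sets up.

import torch.nn as nn

from torch.distributed._shard import shard_module
from torch.distributed._shard.sharder import Sharder
from torch.distributed._shard.sharding_plan import ShardingPlan


class MySharder(Sharder):
    def shard(self, module: nn.Module) -> nn.Module:
        # return the resharded replacement; shard_module swaps it in at the
        # planned submodule path
        return module  # no-op placeholder for illustration


def apply_custom_sharding(model: nn.Module) -> None:
    # a plan maps submodule paths to sharding specs or custom sharders; the
    # paths must not overlap (the error test below checks this)
    plan = ShardingPlan(plan={"ebc": MySharder()})
    shard_module(model, plan)
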
166 lines
5.8 KiB
Python
# Owner(s): ["oncall: distributed"]

import sys
import copy

import torch
import torch.nn as nn
from torch.testing._internal.common_distributed import (
    requires_nccl,
    skip_if_lt_x_gpu,
)
from torch.distributed._shard import shard_module
from torch.distributed._shard.sharding_plan import ShardingPlan
from torch.distributed._shard.sharder import Sharder
from torch.distributed._shard.sharding_spec import ChunkShardingSpec
from torch.distributed._shard.sharded_tensor import ShardedTensor

from torch.testing._internal.common_utils import TEST_WITH_DEV_DBG_ASAN
from torch.testing._internal.distributed._shard.sharded_tensor import (
    TEST_GPU_NUM,
    ShardedTensorTestBase,
    with_comms,
)

if TEST_WITH_DEV_DBG_ASAN:
    print(
        "Skip dev-asan as torch + multiprocessing spawn have known issues",
        file=sys.stderr,
    )
    sys.exit(0)


# a simple collection of embedding bags
class CustomEmbeddingBagCollection(nn.Module):
    def __init__(self, num_bags, num_embeddings_per_bag, num_dims):
        super().__init__()
        self.num_bags = num_bags
        self.embedding_bags: nn.ModuleDict = nn.ModuleDict()

        for i in range(num_bags):
            self.embedding_bags[f"embedding_bag_{i}"] = nn.EmbeddingBag(
                num_embeddings_per_bag,
                num_dims,
                mode="sum")

    def forward(self, inputs):
        outputs = []
        for bag in self.embedding_bags.values():
            outputs.append(bag(inputs))
        return torch.cat(outputs)


# a simple sharded version of the EBC above
class CustomShardedEBC(nn.Module):
    def __init__(self, ebc, split_idx, specs):
        super().__init__()
        self.split_idx = split_idx
        row_spec, col_spec = specs

        # create embedding bags based on the specs
        self.embedding_bags: nn.ModuleDict = nn.ModuleDict()

        assert self.split_idx < ebc.num_bags
        for i in range(ebc.num_bags):
            bag_key = f"embedding_bag_{i}"
            # bags below the split index are sharded row-wise, the rest column-wise
            if i < self.split_idx:
                shard_module(ebc, plan=ShardingPlan(plan={f"embedding_bags.{bag_key}.weight": row_spec}))
            else:
                shard_module(ebc, plan=ShardingPlan(plan={f"embedding_bags.{bag_key}.weight": col_spec}))

            self.embedding_bags[bag_key] = ebc.embedding_bags[bag_key]


class CustomSharder(Sharder):
    def __init__(self, devices, split_sharding_idx):
        self.devices = devices
        self.split_sharding_idx = split_sharding_idx
        self.rowwise_spec = ChunkShardingSpec(dim=0, placements=devices)
        self.colwise_spec = ChunkShardingSpec(dim=1, placements=devices)

    def shard(self, ebc: nn.Module) -> nn.Module:
        if not isinstance(ebc, CustomEmbeddingBagCollection):
            raise RuntimeError("The custom sharder only supports CustomEmbeddingBagCollection")

        return CustomShardedEBC(ebc, self.split_sharding_idx, (self.rowwise_spec, self.colwise_spec))


class TestCustomSharder(ShardedTensorTestBase):

    @with_comms(init_rpc=False)
    @skip_if_lt_x_gpu(TEST_GPU_NUM)
    @requires_nccl()
    def test_custom_sharder(self):
        class MyModule(nn.Module):
            def __init__(self):
                super().__init__()
                self.ebc = CustomEmbeddingBagCollection(10, 10, 8)

            def forward(self, inputs):
                return self.ebc(inputs)

        custom_sharder = CustomSharder(
            devices=[f"rank:{i}/cuda:{i}" for i in range(TEST_GPU_NUM)],
            split_sharding_idx=TEST_GPU_NUM // 2
        )

        sharding_plan = ShardingPlan(
            plan={
                "ebc": custom_sharder,
            })

        local_model = MyModule().cuda(self.rank)
        sharded_model = copy.deepcopy(local_model)

        # shard the module with the provided sharding plan
        shard_module(sharded_model, sharding_plan)

        # check to make sure the module has already been sharded
        emb_bags = sharded_model.ebc.embedding_bags
        self.assertTrue(isinstance(emb_bags["embedding_bag_0"].weight, ShardedTensor))
        self.assertTrue(isinstance(emb_bags["embedding_bag_9"].weight, ShardedTensor))
        self.assertEqual(emb_bags["embedding_bag_0"].weight.sharding_spec(), custom_sharder.rowwise_spec)
        self.assertEqual(emb_bags["embedding_bag_9"].weight.sharding_spec(), custom_sharder.colwise_spec)

        # make sure we can run sharded computation and compare outputs
        # with the local model version
        input = torch.arange(8).reshape((2, 4)).cuda(self.rank)
        local_output = local_model(input)
        sharded_output = sharded_model(input)

        self.assertEqual(local_output, sharded_output)

    @with_comms(init_rpc=False)
    @skip_if_lt_x_gpu(TEST_GPU_NUM)
    @requires_nccl()
    def test_custom_sharder_errors(self):
        custom_sharder = CustomSharder(
            devices=[f"rank:{i}/cuda:{i}" for i in range(TEST_GPU_NUM)],
            split_sharding_idx=TEST_GPU_NUM // 2
        )

        sharding_plan = ShardingPlan(
            plan={
                "": custom_sharder,
            })

        sharded_model = CustomEmbeddingBagCollection(10, 10, 8).cuda(self.rank)

        with self.assertRaisesRegex(
            KeyError, "path must not be empty for custom sharder!"
        ):
            # shard the module with the provided sharding plan
            shard_module(sharded_model, sharding_plan)

        # test a conflicting sharding plan
        spec = ChunkShardingSpec(dim=0, placements=["rank:0/cuda:0", "rank:1/cuda:1"])
        sharding_plan = ShardingPlan(
            plan={
                "embedding_bags.embedding_bag_0.weight": spec,
                "embedding_bags": custom_sharder,
            })

        with self.assertRaisesRegex(
            RuntimeError, "should not conflict with the submodule tree"
        ):
            # shard the module with the provided sharding plan
            shard_module(sharded_model, sharding_plan)