[PT-D] Enable megatron-lm style MLP layers (Changes mainly on sharded linear op) (#69735)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/69735

We want to build a prototype of Megatron-LM so that we can apply PT-D ops to models such as transformers and other Meta flagship models.

The basic idea of Megatron-LM is as follows:
1. Col-wise shard the linear weight and perform the linear op for the first layer.
2. Optionally perform an element-wise op, such as ReLU or GeLU, on the output of step 1. We use GeLU in our example unit test.
3. Row-wise shard the linear weight and perform the linear op for the second layer, taking the output of step 2 as input.

This saves the communication that would otherwise be needed to concatenate the col-wise sharding results and to spread the input across ranks for the row-wise sharding (a sketch follows below).
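For illustration, here is a minimal sketch of that two-layer flow, modeled on the unit test added in this PR (`SimpleMegatronLM`). It is a sketch, not the exact test: it assumes a 4-GPU NCCL process group has already been initialized, and the layer/input sizes are illustrative.

```python
import torch
import torch.distributed as dist
from torch.distributed._shard import shard_parameter
from torch.distributed._shard.sharded_tensor import (
    _collect_local_shard,
    _reshard_output,
)
from torch.distributed._shard.sharding_spec import ChunkShardingSpec


class SimpleMegatronLM(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(17, 12)  # weight sharded col-wise (step 1)
        self.gelu = torch.nn.GELU()         # optional element-wise op (step 2)
        self.fc2 = torch.nn.Linear(12, 29)  # weight sharded row-wise (step 3)

    def forward(self, inp):
        return self.fc2(self.gelu(self.fc1(inp)))


rank = dist.get_rank()
placements = [f"rank:{r}/cuda:{r}" for r in range(4)]
colwise_spec = ChunkShardingSpec(dim=0, placements=placements)
rowwise_spec = ChunkShardingSpec(dim=1, placements=placements)

model = SimpleMegatronLM().cuda(rank)
shard_parameter(model.fc1, "weight", colwise_spec)
shard_parameter(model.fc2, "weight", rowwise_spec)

# fc1 now produces a ShardedTensor, GeLU keeps it sharded, and fc2 produces a
# PartialTensor, so no concat/scatter communication is needed between layers.
# To get a fully synced local result, reshard the output and collect the
# local shard.
reshard_spec = ChunkShardingSpec(dim=0, placements=placements)
model = _collect_local_shard(_reshard_output(model, reshard_spec))

output = model(torch.rand(22, 17).cuda(rank))
```

The unit test added below follows the same pattern and additionally checks gradients and the optimizer step against an unsharded local model.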

The changes are as follows:
1. Return a ShardedTensor for the col-wise sharding in the sharded_linear op.
2. Return a PartialTensor for the row-wise sharding in the sharded_linear op.
3. Leverage the APIs already defined for `reshard` to merge/aggregate the local results into a fully synced local result if needed (see the sketch after this list).
4. Add a helper function to create a sharded tensor based on the local result.
5. Add a unit test that exercises the Megatron-LM idea above and compares it against local ops, including gradients and the optimizer step, to ensure the correctness of the implementation.
6. Refactor the sharded linear unit test to reflect the code changes.
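As a hedged sketch of changes (1) and (3), modeled on the refactored sharded-linear unit test: calling the linear op on a col-wise sharded weight now yields a ShardedTensor, which can then be merged into a fully synced local tensor via the existing `reshard` APIs. The snippet assumes a 4-GPU NCCL process group is already initialized; sizes are illustrative.

```python
import copy

import torch
import torch.distributed as dist
from torch.distributed._shard import shard_parameter
from torch.distributed._shard.sharding_spec import ChunkShardingSpec

rank = dist.get_rank()
spec = ChunkShardingSpec(
    dim=0,  # col-wise sharding of the weight
    placements=[f"rank:{r}/cuda:{r}" for r in range(4)],
)

linear = torch.nn.Linear(17, 12).cuda(rank)
shard_parameter(linear, "weight", spec)  # linear.weight is now a ShardedTensor

inp = torch.rand(22, 17).cuda(rank)
# With this change, the sharded linear op returns a ShardedTensor holding the
# local intermediate result instead of a fully synced plain tensor.
sharded_out = torch.nn.functional.linear(inp, linear.weight, linear.bias)

# Merge/aggregate the local results into a fully synced local tensor if needed.
reshard_spec = copy.deepcopy(spec)
reshard_spec.dim = 0
local_out = sharded_out.reshard(reshard_spec).local_tensor()
```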
ghstack-source-id: 148273049

Test Plan: Unit test + CI

Reviewed By: pritamdamania87

Differential Revision: D32978221

fbshipit-source-id: 565fc92e7807e19d53b0261f8ace3945bef69e3e
(cherry picked from commit 344abe7520)
Junjie Wang (PyTorch) 2022-02-02 22:06:43 -08:00 committed by PyTorch MergeBot
parent 19d0de8a57
commit 88547396eb
16 changed files with 734 additions and 94 deletions

View File

@ -0,0 +1,72 @@
# Owner(s): ["oncall: distributed"]
import sys
import torch
from torch.distributed._shard import sharded_tensor
from torch.testing._internal.common_distributed import (
requires_nccl,
skip_if_lt_x_gpu,
)
from torch.testing._internal.common_utils import (
TEST_WITH_DEV_DBG_ASAN,
run_tests,
)
from torch.testing._internal.distributed._shard.sharded_tensor import (
TEST_GPU_NUM,
ShardedTensorTestBase,
with_comms,
)
from torch.testing._internal.distributed._shard.sharded_tensor._test_ops_common import (
generate_chunk_sharding_specs_for_test,
)
if TEST_WITH_DEV_DBG_ASAN:
print(
"Skip dev-asan as torch + multiprocessing spawn have known issues",
file=sys.stderr,
)
sys.exit(0)
class TestShardedTensorElementWiseOps(ShardedTensorTestBase):
def _run_sharded_elementwise_ops(self, spec, input_size, op):
torch.manual_seed(self.rank)
st = sharded_tensor.rand(spec, *input_size)
new_st = op(st)
local_shard = st.local_tensor()
new_st_local_shard = new_st.local_tensor()
self.assertEqual(
op(local_shard),
new_st_local_shard,
)
@with_comms(init_rpc=False)
@skip_if_lt_x_gpu(TEST_GPU_NUM)
@requires_nccl()
def test_sharded_gelu(self):
specs = generate_chunk_sharding_specs_for_test(
0
) + generate_chunk_sharding_specs_for_test(1)
for spec in specs:
self._run_sharded_elementwise_ops(spec, [12, 17], torch.nn.functional.gelu)
self._run_sharded_elementwise_ops(spec, [18, 21], torch.nn.functional.gelu)
self._run_sharded_elementwise_ops(spec, [17, 23], torch.nn.functional.gelu)
self._run_sharded_elementwise_ops(spec, [14, 15], torch.nn.functional.gelu)
@with_comms(init_rpc=False)
@skip_if_lt_x_gpu(TEST_GPU_NUM)
@requires_nccl()
def test_sharded_relu(self):
specs = generate_chunk_sharding_specs_for_test(
0
) + generate_chunk_sharding_specs_for_test(1)
for spec in specs:
self._run_sharded_elementwise_ops(spec, [12, 17], torch.nn.functional.relu)
self._run_sharded_elementwise_ops(spec, [18, 21], torch.nn.functional.relu)
self._run_sharded_elementwise_ops(spec, [17, 23], torch.nn.functional.relu)
self._run_sharded_elementwise_ops(spec, [14, 15], torch.nn.functional.relu)
if __name__ == "__main__":
run_tests()

View File

@ -21,6 +21,7 @@ from torch.testing._internal.distributed._shard.sharded_tensor import (
with_comms,
)
from torch.testing._internal.distributed._shard.sharded_tensor._test_ops_common import (
clone_module_parameter,
generate_chunk_sharding_specs_for_test,
generate_local_weight_sharding_params_for_test,
)
@ -64,8 +65,8 @@ class TestShardedEmbedding(ShardedTensorTestBase):
)
# Copy the weights from local embedding
sharded_embedding.weight = torch.nn.Parameter(
local_embedding.weight.detach().clone()
sharded_embedding.weight = clone_module_parameter(
local_embedding, "weight"
)
# Shard the parameter.

View File

@ -21,6 +21,7 @@ from torch.testing._internal.distributed._shard.sharded_tensor import (
with_comms,
)
from torch.testing._internal.distributed._shard.sharded_tensor._test_ops_common import (
clone_module_parameter,
generate_chunk_sharding_specs_for_test,
generate_local_weight_sharding_params_for_test,
)
@ -71,8 +72,8 @@ class TestShardedEmbeddingBag(ShardedTensorTestBase):
)
# Copy the weights from local embedding bag.
sharded_embedding_bag.weight = torch.nn.Parameter(
local_embedding_bag.weight.detach().clone()
sharded_embedding_bag.weight = clone_module_parameter(
local_embedding_bag, "weight"
)
# Shard the parameter.

View File

@ -1,21 +1,24 @@
# Owner(s): ["oncall: distributed"]
import copy
import sys
import torch
import torch.distributed as dist
from torch.distributed._shard import shard_parameter
from torch.distributed._shard.sharded_optim import (
ShardedOptimizer,
named_params_with_sharded_tensor,
)
from torch.distributed._shard.sharded_tensor import (
empty,
_collect_local_shard,
_reshard_output,
)
from torch.distributed._shard.sharding_spec import (
ChunkShardingSpec,
EnumerableShardingSpec,
ShardMetadata
)
from torch.distributed._shard.sharded_optim import (
ShardedOptimizer,
named_params_with_sharded_tensor,
ShardMetadata,
)
from torch.testing._internal.common_distributed import (
requires_nccl,
@ -31,6 +34,7 @@ from torch.testing._internal.distributed._shard.sharded_tensor import (
with_comms,
)
from torch.testing._internal.distributed._shard.sharded_tensor._test_ops_common import (
clone_module_parameter,
generate_chunk_sharding_specs_for_test,
generate_local_weight_sharding_params_for_test,
)
@ -44,16 +48,17 @@ if TEST_WITH_DEV_DBG_ASAN:
class TestShardedTensorOpsLinear(ShardedTensorTestBase):
def _run_sharded_linear(self, spec, input_size, linear_size, sharded_dim):
def _run_sharded_linear(
self, spec, input_size, linear_size, sharded_dim
):
# Use same seed.
torch.manual_seed(0)
local_linear = torch.nn.Linear(*linear_size).cuda(self.rank)
sharded_linear = torch.nn.Linear(*linear_size)
# Copy the weights and bias from local linear
sharded_linear.weight = torch.nn.Parameter(local_linear.weight.detach().clone())
sharded_linear.bias = torch.nn.Parameter(local_linear.bias.detach().clone())
sharded_linear.weight = clone_module_parameter(local_linear, "weight")
sharded_linear.bias = clone_module_parameter(local_linear, "bias")
# Shard the parameter.
shard_parameter(sharded_linear, "weight", spec)
@ -61,6 +66,11 @@ class TestShardedTensorOpsLinear(ShardedTensorTestBase):
# Run sharded computation
torch.manual_seed(self.rank) # inputs different on each rank
inp = torch.rand(*input_size).cuda(self.rank)
reshard_spec = copy.deepcopy(spec)
reshard_spec.dim = 0
sharded_linear = _collect_local_shard(
_reshard_output(sharded_linear, reshard_spec)
)
sharded_output = sharded_linear(inp)
# Run local computation
@ -76,6 +86,11 @@ class TestShardedTensorOpsLinear(ShardedTensorTestBase):
sharded_output = torch.nn.functional.linear(
inp, sharded_linear.weight, sharded_linear.bias
)
sharded_output = sharded_output.reshard(reshard_spec).local_tensor()
# When the local tensor only has one dimension, an extra dimension is added
# for resharding. We need to squeeze that extra dimension out manually.
if inp.dim() == 1:
sharded_output = sharded_output.squeeze(reshard_spec.dim)
self.assertEqual(local_output, sharded_output)
# Compute loss and run backward pass.
@ -84,7 +99,7 @@ class TestShardedTensorOpsLinear(ShardedTensorTestBase):
local_grad = local_linear.weight.grad
# Verify that both weight and bias in the sharded linear has non-None grad.
sharded_weight = sharded_linear.weight.local_shards()[0].tensor
sharded_weight = sharded_linear.weight.local_tensor()
self.assertNotEqual(sharded_linear.bias.grad, None)
self.assertNotEqual(sharded_weight.grad, None)
@ -94,9 +109,11 @@ class TestShardedTensorOpsLinear(ShardedTensorTestBase):
local_linear.weight, sharded_dim, TEST_GPU_NUM, spec, self.rank
)
local_grad_narrowed = local_grad.narrow(sharded_dim, start_pos, chunk_size)
local_bias_grad = local_linear.bias.grad
dist.all_reduce(local_bias_grad)
# Test backward gradient calculation.
self.assertEqual(sharded_linear.bias.grad, local_linear.bias.grad)
self.assertEqual(sharded_linear.bias.grad, local_bias_grad)
self.assertEqual(sharded_weight.grad, local_grad_narrowed)
# Test optimizer.
@ -106,9 +123,13 @@ class TestShardedTensorOpsLinear(ShardedTensorTestBase):
self.assertNotEqual(previous, local_linear.weight)
previous_sharded_weight = sharded_weight.clone()
previous_sharded_bias = sharded_linear.bias.clone()
sharded_optim = ShardedOptimizer(dict(named_params_with_sharded_tensor(sharded_linear)), torch.optim.SGD, lr=0.1)
sharded_optim = ShardedOptimizer(
dict(named_params_with_sharded_tensor(sharded_linear)),
torch.optim.SGD,
lr=0.1,
)
sharded_optim.step()
sharded_weight = sharded_linear.weight.local_shards()[0].tensor
sharded_weight = sharded_linear.weight.local_tensor()
local_weight_narrowed = local_linear.weight.narrow(
sharded_dim, start_pos, chunk_size
)
@ -169,7 +190,7 @@ class TestShardedTensorOpsLinear(ShardedTensorTestBase):
for spec in generate_chunk_sharding_specs_for_test(0):
fc1 = torch.nn.Linear(10, 10).cuda(self.rank)
shard_parameter(fc1, "bias", spec)
with self.assertRaisesRegex(TypeError, 'input and bias need to be torch.Tensor'):
with self.assertRaisesRegex(TypeError, 'bias needs to be torch.Tensor'):
fc1(torch.rand(10, 10).cuda(self.rank))
fc2 = torch.nn.Linear(10, 10).cuda(self.rank)

View File

@ -0,0 +1,224 @@
# Owner(s): ["oncall: distributed"]
import copy
import sys
import torch
import torch.distributed as dist
from torch.distributed._shard.sharded_optim import (
ShardedOptimizer,
named_params_with_sharded_tensor,
)
from torch.distributed._shard import (
shard_parameter,
)
from torch.distributed._shard.sharded_tensor import (
_collect_local_shard,
_reshard_output,
)
from torch.testing._internal.common_distributed import (
requires_nccl,
skip_if_lt_x_gpu,
)
from torch.testing._internal.common_utils import (
TEST_WITH_DEV_DBG_ASAN,
run_tests,
)
from torch.testing._internal.distributed._shard.sharded_tensor import (
TEST_GPU_NUM,
ShardedTensorTestBase,
with_comms,
)
from torch.testing._internal.distributed._shard.sharded_tensor._test_ops_common import (
clone_module_parameter,
generate_chunk_sharding_specs_for_test,
generate_local_weight_sharding_params_for_test,
)
if TEST_WITH_DEV_DBG_ASAN:
print(
"Skip dev-asan as torch + multiprocessing spawn have known issues",
file=sys.stderr,
)
sys.exit(0)
class TestShardedTensorMegatronLinear(ShardedTensorTestBase):
class SimpleMegatronLM(torch.nn.Module):
def __init__(self, linear_size, rank=None):
super().__init__()
self.fc1 = torch.nn.Linear(*linear_size[0])
self.gelu = torch.nn.GELU()
self.fc2 = torch.nn.Linear(*linear_size[1])
if rank:
self.fc1.cuda(rank)
self.fc2.cuda(rank)
def forward(self, inp):
return self.fc2(self.gelu(self.fc1(inp)))
def _run_megatron_linear(self, spec, input_size, linear_size):
def _weight_override(module_dst, module_src):
module_dst.fc1.weight = clone_module_parameter(module_src.fc1, "weight")
module_dst.fc1.bias = clone_module_parameter(module_src.fc1, "bias")
module_dst.fc2.weight = clone_module_parameter(module_src.fc2, "weight")
module_dst.fc2.bias = clone_module_parameter(module_src.fc2, "bias")
def _shard_parameter(module, spec):
shard_parameter(module.fc1, "weight", spec[0])
shard_parameter(module.fc2, "weight", spec[1])
def _get_weight_grad(module):
return (module.fc1.weight.grad, module.fc2.weight.grad)
def _get_bias_grad(module):
return (module.fc1.bias.grad, module.fc2.bias.grad)
def _get_weights(module):
return (module.fc1.weight, module.fc2.weight)
def _get_bias(module):
return (module.fc1.bias, module.fc2.bias)
def _get_weight_local_shard(module):
return (
module.fc1.weight.local_tensor(),
module.fc2.weight.local_tensor(),
)
# Use same seed.
torch.manual_seed(0)
local_megatron_lm = self.SimpleMegatronLM(linear_size, rank=self.rank).cuda(
self.rank
)
sharded_megatron_lm = self.SimpleMegatronLM(linear_size)
_weight_override(sharded_megatron_lm, local_megatron_lm)
# Shard the parameter. First col-wise sharding and then row-wise
_shard_parameter(sharded_megatron_lm, spec)
# Run sharded computation
torch.manual_seed(self.rank) # inputs different on each rank
inp = torch.rand(*input_size).cuda(self.rank)
reshard_spec = copy.deepcopy(spec[1])
reshard_spec.placements.sort(key=lambda placement: placement.rank())
reshard_spec.dim = 0
sharded_megatron_lm = _collect_local_shard(
_reshard_output(sharded_megatron_lm, reshard_spec)
)
sharded_output = sharded_megatron_lm(inp)
# Run local computation
local_output = local_megatron_lm(inp)
# Verify
self.assertEqual(local_output, sharded_output)
# Compute loss and run backward pass.
local_output.sum().backward()
sharded_output.sum().backward()
(
local_weight_grad_fc1,
local_weight_grad_fc2,
) = _get_weight_grad(local_megatron_lm)
local_bias_grad_fc1, local_bias_grad_fc2 = _get_bias_grad(local_megatron_lm)
# Verify that the weights and biases in both sharded linear layers have non-None grads.
(
sharded_weight_fc1,
sharded_weight_fc2,
) = _get_weight_local_shard(sharded_megatron_lm)
bias_grad_fc1, bias_grad_fc2 = _get_bias_grad(sharded_megatron_lm)
self.assertNotEqual(sharded_weight_fc1.grad, None)
self.assertNotEqual(sharded_weight_fc2.grad, None)
self.assertNotEqual(bias_grad_fc1, None)
self.assertNotEqual(bias_grad_fc2, None)
# Shard the local linear's weight grad so that we can compare.
dist.all_reduce(local_weight_grad_fc1)
dist.all_reduce(local_weight_grad_fc2)
dist.all_reduce(local_bias_grad_fc1)
dist.all_reduce(local_bias_grad_fc2)
local_weight_fc1, local_weight_fc2 = _get_weights(local_megatron_lm)
(
start_pos_fc1,
chunk_size_fc1,
) = generate_local_weight_sharding_params_for_test(
local_weight_fc1, 0, TEST_GPU_NUM, spec[0], self.rank
)
local_grad_narrowed_fc1 = local_weight_grad_fc1.narrow(
0, start_pos_fc1, chunk_size_fc1
)
(
start_pos_fc2,
chunk_size_fc2,
) = generate_local_weight_sharding_params_for_test(
local_weight_fc2, 1, TEST_GPU_NUM, spec[1], self.rank
)
local_grad_narrowed_fc2 = local_weight_grad_fc2.narrow(
1, start_pos_fc2, chunk_size_fc2
)
# Test backward gradient calculation.
self.assertEqual(sharded_weight_fc1.grad, local_grad_narrowed_fc1)
self.assertEqual(sharded_weight_fc2.grad, local_grad_narrowed_fc2)
self.assertEqual(bias_grad_fc1, local_bias_grad_fc1)
self.assertEqual(bias_grad_fc2, local_bias_grad_fc2)
# Test optimizer.
bias_fc1, bias_fc2 = _get_bias(sharded_megatron_lm)
local_bias_fc1, local_bias_fc2 = _get_bias(local_megatron_lm)
self.assertEqual(bias_fc1, local_bias_fc1)
self.assertEqual(bias_fc2, local_bias_fc2)
self.assertEqual(bias_fc1.grad, local_bias_fc1.grad)
self.assertEqual(bias_fc2.grad, local_bias_fc2.grad)
previous_sharded_weight_fc1 = sharded_weight_fc1.clone()
previous_sharded_weight_fc2 = sharded_weight_fc2.clone()
previous_bias_fc1 = bias_fc1.clone()
previous_bias_fc2 = bias_fc2.clone()
optim = torch.optim.SGD(local_megatron_lm.parameters(), lr=0.1)
optim.step()
sharded_optim = ShardedOptimizer(
dict(named_params_with_sharded_tensor(sharded_megatron_lm)),
torch.optim.SGD,
lr=0.1,
)
sharded_optim.step()
local_weight_fc1_narrowed = local_weight_fc1.narrow(
0, start_pos_fc1, chunk_size_fc1
)
local_weight_fc2_narrowed = local_weight_fc2.narrow(
1, start_pos_fc2, chunk_size_fc2
)
# Test weight value after optimizer.
self.assertEqual(sharded_weight_fc1.size(), local_weight_fc1_narrowed.size())
self.assertEqual(sharded_weight_fc2.size(), local_weight_fc2_narrowed.size())
self.assertNotEqual(previous_sharded_weight_fc1, sharded_weight_fc1)
self.assertNotEqual(previous_sharded_weight_fc2, sharded_weight_fc2)
self.assertEqual(sharded_weight_fc1, local_weight_fc1_narrowed)
self.assertEqual(sharded_weight_fc2, local_weight_fc2_narrowed)
# Test bias value after optimizer.
local_bias_fc1, local_bias_fc2 = _get_bias(local_megatron_lm)
self.assertNotEqual(previous_bias_fc1, bias_fc1)
self.assertEqual(bias_fc1, local_bias_fc1)
self.assertNotEqual(previous_bias_fc2, bias_fc2)
self.assertEqual(bias_fc2, local_bias_fc2)
@with_comms(init_rpc=False)
@skip_if_lt_x_gpu(TEST_GPU_NUM)
@requires_nccl()
def test_megatron_two_layer_prototype(self):
colwise_sharding_spec = generate_chunk_sharding_specs_for_test(0)
rowwise_sharding_spec = generate_chunk_sharding_specs_for_test(1)
for spec in zip(colwise_sharding_spec, rowwise_sharding_spec):
self._run_megatron_linear(spec, [22, 17], [[17, 12], [12, 29]])
self._run_megatron_linear(spec, [28, 21], [[21, 11], [11, 29]])
self._run_megatron_linear(spec, [37, 23], [[23, 13], [13, 24]])
self._run_megatron_linear(spec, [24, 15], [[15, 14], [14, 20]])
if __name__ == "__main__":
run_tests()

View File

@ -12,6 +12,7 @@ from torch.distributed._shard.sharding_spec._internals import (
check_tensor,
get_split_size,
get_chunked_dim_size,
get_chunk_sharding_params,
)
from torch.testing._internal.common_utils import (
@ -248,5 +249,32 @@ class TestShardingSpec(TestCase):
self.assertEqual(1, get_chunked_dim_size(13, 4, 3))
self.assertEqual(0, get_chunked_dim_size(5, 2, 3))
def test_get_chunk_sharding_params(self):
ranks = [
"rank:0/cuda:0",
"rank:1/cuda:1",
"rank:2/cuda:2",
"rank:3/cuda:3",
]
spec = ChunkShardingSpec(
dim=0,
placements=ranks,
)
result = get_chunk_sharding_params(21, 4, spec, 1)
self.assertEqual(6, result[0])
self.assertEqual(6, result[1])
result = get_chunk_sharding_params(21, 4, spec, 3)
self.assertEqual(18, result[0])
self.assertEqual(3, result[1])
ranks[1], ranks[2] = ranks[2], ranks[1]
ranks[0], ranks[3] = ranks[3], ranks[0]
spec.placements = ranks
result = get_chunk_sharding_params(21, 4, spec, 1)
self.assertEqual(12, result[0])
self.assertEqual(6, result[1])
result = get_chunk_sharding_params(21, 4, spec, 3)
self.assertEqual(0, result[0])
self.assertEqual(6, result[1])
if __name__ == '__main__':
run_tests()

View File

@ -203,9 +203,11 @@ WINDOWS_BLOCKLIST = [
"distributed/elastic/agent/server/test/api_test",
"distributed/elastic/multiprocessing/api_test",
"distributed/_shard/sharding_spec/test_sharding_spec",
"distributed/_shard/sharded_tensor/test_megatron_prototype",
"distributed/_shard/sharded_tensor/test_sharded_tensor",
"distributed/_shard/sharded_tensor/test_sharded_tensor_reshard",
"distributed/_shard/sharded_tensor/test_partial_tensor",
"distributed/_shard/sharded_tensor/ops/test_elementwise_ops",
"distributed/_shard/sharded_tensor/ops/test_embedding",
"distributed/_shard/sharded_tensor/ops/test_embedding_bag",
"distributed/_shard/sharded_tensor/ops/test_binary_cmp",
@ -219,9 +221,11 @@ ROCM_BLOCKLIST = [
"distributed/rpc/test_faulty_agent",
"distributed/rpc/test_tensorpipe_agent",
"distributed/rpc/cuda/test_tensorpipe_agent",
"distributed/_shard/sharded_tensor/test_megatron_prototype",
"distributed/_shard/sharded_tensor/test_sharded_tensor",
"distributed/_shard/sharded_tensor/test_sharded_tensor_reshard",
"distributed/_shard/sharded_tensor/test_partial_tensor",
"distributed/_shard/sharded_tensor/ops/test_elementwise_ops",
"distributed/_shard/sharded_tensor/ops/test_embedding",
"distributed/_shard/sharded_tensor/ops/test_embedding_bag",
"distributed/_shard/sharded_tensor/ops/test_binary_cmp",
@ -360,9 +364,11 @@ DISTRIBUTED_TESTS = [
"distributed/elastic/utils/distributed_test",
"distributed/elastic/multiprocessing/api_test",
"distributed/_shard/sharding_spec/test_sharding_spec",
"distributed/_shard/sharded_tensor/test_megatron_prototype",
"distributed/_shard/sharded_tensor/test_sharded_tensor",
"distributed/_shard/sharded_tensor/test_sharded_tensor_reshard",
"distributed/_shard/sharded_tensor/test_partial_tensor",
"distributed/_shard/sharded_tensor/ops/test_elementwise_ops",
"distributed/_shard/sharded_tensor/ops/test_embedding",
"distributed/_shard/sharded_tensor/ops/test_embedding_bag",
"distributed/_shard/sharded_tensor/ops/test_binary_cmp",

View File

@ -439,6 +439,12 @@ def _collect_local_shard(module: torch.nn.Module) -> torch.nn.Module:
"""
Hook a module with local shards collection in the forward pass.
This API is typically used to convert a sharded representation back to data parallel
representation. In particular, it returns the local tensor for this Shard. If the
size along the sharding dimension for the local tensor is 1, this dimension is removed
from the final result. For example a [4, 16] ShardedTensor across 4 ranks is typically
a local Tensor of size [16] across each rank and not [1, 16] across each rank.
Args:
module (:class:`torch.nn.Module`): Module whose output needs to be resharded.
@ -448,7 +454,12 @@ def _collect_local_shard(module: torch.nn.Module) -> torch.nn.Module:
def hook_func(_module, _input, output):
if isinstance(output, ShardedTensor):
return output.local_tensor()
return output
local_tensor = output.local_tensor()
# Squeeze out the size-1 sharding dimension manually.
if local_tensor.size(output._sharding_spec.dim) == 1: # type: ignore[attr-defined]
local_tensor = local_tensor.squeeze(
output._sharding_spec.dim # type: ignore[attr-defined]
)
return local_tensor
module.register_forward_hook(hook_func)
return module

View File

@ -1,8 +1,7 @@
from .init import kaiming_uniform_, normal_, uniform_
from .linear import sharded_linear
import torch.distributed._shard.sharded_tensor._ops.elementwise_ops
from .binary_cmp import equal, allclose
from .embedding import sharded_embedding
from .embedding_bag import sharded_embedding_bag
from .binary_cmp import (
equal,
allclose
)
from .init import kaiming_uniform_, normal_, uniform_
from .linear import sharded_linear

View File

@ -16,7 +16,6 @@ from torch.distributed.nn.functional import (
def _handle_col_wise_sharding_base(
op_func,
sharding_dim_size,
col_dim,
input,
world_size,
@ -39,7 +38,6 @@ def _handle_col_wise_sharding_base(
Args:
op_func: operator which is applied to the input tensor.
sharding_dim_size: the max size of the column each rank gets.
col_dim: dim of result tensor after the operation.
input: tensor to be applied op on.
world_size: number of ranks.
@ -90,7 +88,7 @@ def _handle_col_wise_sharding_base(
# Distribute results to each rank with col rearrangement.
output = _result_distribute_with_col_rearrange(
results, input, sharding_dim_size, world_size, weight, pg
results, input, world_size, weight, pg
)
# transpose the output and return result.
@ -98,7 +96,7 @@ def _handle_col_wise_sharding_base(
def _result_distribute_with_col_rearrange(
results, input, sharding_dim_size, world_size, weight, pg
results, input, world_size, weight, pg
):
"""
For col-wise sharding of weight, we need to distribute
@ -111,7 +109,6 @@ def _result_distribute_with_col_rearrange(
results: results from ops applied to inputs from all ranks.
We need to distribute them back to their original ranks.
input: tensor to be applied op to.
sharding_dim_size: the max size of the column each rank gets.
world_size: number of ranks.
weight: sharded weight tensor.
pg: process group.
@ -119,6 +116,8 @@ def _result_distribute_with_col_rearrange(
Return: column rearranged result.
"""
# Process results and outputs for all2all.
sharding_dim = weight._sharding_spec.dim
sharding_dim_size = weight.size(sharding_dim)
dims = list(results[0].size())
dims[0] = sharding_dim_size
output = torch.empty(*dims, device=input.device)

View File

@ -0,0 +1,38 @@
import copy
import torch
from torch.distributed._shard.sharded_tensor import (
sharded_op_impl,
Shard,
ShardedTensor,
)
def register_elementwise_op(op):
@sharded_op_impl(op)
def elementwise_op(types, args=(), kwargs=None, pg=None):
"""
Handles ``__torch_function__`` dispatch for the elementwise op such
as ``torch.nn.functional.gelu`` or ``torch.nn.functional.relu``.
This method computes on either a normal tensor or a sharded tensor.
"""
input = args[0]
# Validate types
if not isinstance(input, ShardedTensor):
raise TypeError("input needs to be a ShardedTensor")
local_shards_new = []
for local_shard in input.local_shards():
local_shards_new.append(Shard(op(local_shard.tensor), local_shard.metadata))
# TODO: After a new API for sharded tensor creation, we need to replace this.
# https://github.com/pytorch/pytorch/issues/72092
new_st = ShardedTensor._init_from_local_shards(
local_shards_new, input.size(), process_group=pg
)
# Manually set sharding_spec
new_st._sharding_spec = copy.deepcopy(input._sharding_spec)
return new_st
register_elementwise_op(torch.nn.functional.gelu)
register_elementwise_op(torch.nn.functional.relu)

View File

@ -104,7 +104,7 @@ def sharded_embedding(types, args, kwargs, pg):
norm_type = kwargs.get("norm_type")
padding_idx = kwargs.get("padding_idx")
local_shard = weight.local_shards()[0].tensor.contiguous()
local_shard = weight.local_tensor().contiguous()
sharding_dim = weight._sharding_spec.dim
world_size = dist.get_world_size(pg)
rank = dist.get_rank(pg)
@ -224,7 +224,6 @@ def _handle_col_wise_sharding(
output = _handle_col_wise_sharding_base(
torch.nn.functional.embedding,
weight.size(1),
len(input.size()),
input,
world_size,

View File

@ -123,7 +123,7 @@ def sharded_embedding_bag(types, args, kwargs, pg):
include_last_offset = kwargs.get("include_last_offset")
padding_idx = kwargs.get("padding_idx")
local_shard = weight.local_shards()[0].tensor.contiguous()
local_shard = weight.local_tensor().contiguous()
sharding_dim = weight._sharding_spec.dim
world_size = dist.get_world_size(pg)
rank = dist.get_rank(pg)
@ -319,7 +319,6 @@ def _handle_col_wise_sharding(
output = _handle_col_wise_sharding_base(
torch.nn.functional.embedding_bag,
weight.size(1),
1,
input,
world_size,

View File

@ -1,20 +1,29 @@
import copy
from typing import List, cast
import torch
import torch.distributed as dist
from torch.autograd import Function
from torch.distributed.nn.functional import (
all_gather,
all_to_all_single,
reduce_scatter,
)
from torch.distributed._shard.sharded_tensor import sharded_op_impl, ShardedTensor
from torch.distributed._shard.sharded_tensor import (
sharded_op_impl,
_PartialTensor,
Shard,
ShardedTensor,
ShardMetadata,
)
from torch.distributed._shard.sharding_spec import ChunkShardingSpec
from torch.distributed._shard.sharding_spec._internals import (
get_split_size,
get_chunked_dim_size,
get_chunk_sharding_params,
)
from ._common import (
_handle_col_wise_sharding_base,
_result_distribute_with_col_rearrange,
)
@ -27,6 +36,9 @@ def sharded_linear(types, args, kwargs, pg):
1. Supports only sharding of ``weight``.
2. Supports only ``ChunkShardingSpec``.
3. Supports only a single local shard per rank.
4. Tailored for Megatron-LM style model(tensor) parallelism. Further API
calls are needed if a fully synced local tensor is needed.
Megatron-LM paper link: https://arxiv.org/abs/1909.08053
Based on the dimension that the weight is sharded on, there are two
algorithms:
@ -50,11 +62,13 @@ def sharded_linear(types, args, kwargs, pg):
size that we need for the global result which would be (13 x 16)
multiplied by (16 x 17). But the final result needs to be aggregated
across the rest of the ranks.
3. Now the local matmul results are aggregated and shared to all the
corresponding ranks using a reduce_scatter operation ensuring each rank
3. Here we just return the partial result. One can call the API
aggregate_partial_tensor_list to get the aggregated final result.
The API uses a reduce_scatter operation ensuring each rank
aggregates its own result. This is essentially a sum operation across
all the (13 x 17) local computations we did for each rank.
4. Finally, we add the bias term locally to the final computation.
4. For the partial result, we only add 1 / n of the bias term to the partial
result, where n is the number of GPUs.
COLWISE SHARDING
================
@ -72,53 +86,39 @@ def sharded_linear(types, args, kwargs, pg):
2. Next we perform local matmuls by multiplying each input (13 x 17)
with the local shard (17 x 4) (transposed). This results in 4 (13 x 4)
matrices on each rank.
3. Next, we concat these 4 matrices and perform an all2all to share the
appropriate (13 x 4) matrices to each rank.
4. Now, each rank receives a (13 x 16) matrix which is basically the size
of the result we need.
3. Next, we stack them into a (4 x 13 x 4) tensor and build a sharded
tensor across 4 ranks.
4. To merge them into a fully synced local tensor, one can call the API
merge_sharded_local_results.
This API concats these 4 matrices and performs an all2all to share the
appropriate (13 x 4) matrices with each rank. Specifically, each rank
receives a (13 x 16) matrix, which is basically the size of the result.
5. If placements are not in order any appropriate rearrangement of rows
are done for the (13 x 16) matrix and finally the bias term is added.
"""
# Validate input params
_validate_linear_op_param(args, kwargs)
input = args[0]
weight = args[1]
bias = args[2]
# Validate types
if not isinstance(input, torch.Tensor) or not isinstance(bias, torch.Tensor):
raise TypeError("input and bias need to be torch.Tensor")
if not isinstance(weight, ShardedTensor):
raise TypeError("weight needs to be ShardedTensor")
if len(input.size()) < 1:
raise ValueError("Input needs to have at least 1 dim")
weight_size = cast(torch.Size, weight.size())
if len(weight_size) != 2:
raise ValueError("Weight needs to have exactly 2 dims")
if len(bias.size()) != 1:
raise ValueError("Bias needs to have exactly 1 dim")
if input.size()[-1] != weight_size[1]:
raise ValueError(
f"Input dim: {input.size()[-1]} does not match "
f"appropriate weight dim: {weight_size[1]}"
)
if not isinstance(weight._sharding_spec, ChunkShardingSpec):
raise ValueError("Only ChunkShardingSpec supported for ShardedTensor ops!")
if len(weight.local_shards()) != 1:
raise ValueError("Only one local shard supported!")
local_shard = weight.local_tensor()
local_shard_t = local_shard.t().contiguous()
sharding_dim = weight._sharding_spec.dim
world_size = dist.get_world_size(pg)
rank = dist.get_rank(pg)
if sharding_dim == 1:
return _handle_row_wise_sharding(
if sharding_dim == 1 and isinstance(input, torch.Tensor):
return _handle_row_wise_sharding_tensor(
input, world_size, weight, rank, local_shard_t, bias, pg
)
elif sharding_dim == 1 and isinstance(input, ShardedTensor):
return _handle_row_wise_sharding_sharded_tensor(
input, world_size, weight, local_shard_t, bias, pg
)
elif sharding_dim == 0:
return _handle_col_wise_sharding(
input, world_size, weight, local_shard_t, bias, pg
input, world_size, weight, rank, local_shard_t, bias, pg
)
else:
raise RuntimeError(
@ -126,38 +126,106 @@ def sharded_linear(types, args, kwargs, pg):
)
def _handle_col_wise_sharding(input, world_size, weight, local_shard_t, bias, pg):
def _validate_linear_op_param(args, kwargs):
"""
Validate input params of the sharded linear op.
Args:
input: input of the linear layer.
weight: sharded weight tensor.
kwargs: same as normal Linear.
Return: None.
"""
input = args[0]
weight = args[1]
bias = args[2]
# Validate types
if not isinstance(input, torch.Tensor) and not isinstance(input, ShardedTensor):
raise TypeError("input needs to be either torch.Tensor or ShardedTensor")
if not isinstance(bias, torch.Tensor):
raise TypeError("bias needs to be torch.Tensor")
if not isinstance(weight, ShardedTensor):
raise TypeError("weight needs to be ShardedTensor")
if len(input.size()) < 1: # type: ignore[arg-type]
raise ValueError("Input needs to have at least 1 dim")
weight_size = cast(torch.Size, weight.size())
if len(weight_size) != 2:
raise ValueError("Weight needs to have exactly 2 dims")
if len(bias.size()) != 1:
raise ValueError("Bias needs to have exactly 1 dim")
if input.size()[-1] != weight_size[1]: # type: ignore[index]
raise ValueError(
f"Input dim: {input.size()[-1]} does not match " # type: ignore[index]
f"appropriate weight dim: {weight_size[1]}"
)
if not isinstance(weight._sharding_spec, ChunkShardingSpec):
raise ValueError("Only ChunkShardingSpec supported for ShardedTensor ops!")
if len(weight.local_shards()) != 1:
raise ValueError("Only one local shard supported!")
def _handle_col_wise_sharding(input, world_size, weight, rank, local_shard_t, bias, pg):
"""
Entry-point function to handle the logic of col-wise sharding of weight
for Linear. (Detailed explanations of the logic can be found in the
comment for sharded_linear.)
When the local tensor has only one dimension, we add one more dimension
for resharding; the extra dimension needs to be squeezed out manually later on.
For example, if we have:
input: size[15]
weight: size[15, 16]
world_size: 4
On each rank, we will have 4 * [4] tensors. We then stack them into a [4, 4]
tensor and generate a sharded tensor sharded along dim 1.
For all other cases, we simply concatenate the local tensors; no further action
is needed afterward.
Args:
input: matrix to be multiplied with the sharded weight.
world_size: number of ranks.
weight: sharded weight tensor.
rank: rank of the current CUDA process.
local_shard_t: transposed local shard of the weight, used for the local matmul.
bias: bias term of linear op.
pg: process group.
Returns: final result of linear operation.
Returns:
A :class:`ShardedTensor` object which filled with local intermediate results.
"""
return (
_handle_col_wise_sharding_base(
torch.matmul,
weight.size(0),
len(input.size()) - 1,
input,
world_size,
weight,
local_shard_t,
pg,
)
+ bias
# allgather the inputs first.
gathered_inputs = all_gather(input, group=pg)
(start_pos, chunk_size) = get_chunk_sharding_params(
bias.size(0), world_size, weight._sharding_spec, rank
)
local_bias = _BiasTensorNarrow.apply(
world_size, start_pos, chunk_size, weight, pg, bias
)
results = [None] * world_size
indices = {}
for idx, placement in enumerate(weight._sharding_spec.placements):
indices[placement.rank()] = idx
for i, inp in enumerate(gathered_inputs):
results[indices[i]] = inp.matmul(local_shard_t) + local_bias
# When the local result only has one dimension, we need to make sure
# it is not sharded along dim 0, so that reshard can work properly.
if results[0].dim() == 1: # type: ignore[attr-defined]
result = torch.stack(results) # type: ignore[arg-type]
else:
result = torch.cat(results) # type: ignore[arg-type]
return _init_sharded_tensor_from_local_result(
weight, result, 0, -1, world_size, pg # type: ignore[arg-type]
)
def _handle_row_wise_sharding(input, world_size, weight, rank, local_shard_t, bias, pg):
def _handle_row_wise_sharding_tensor(
input, world_size, weight, rank, local_shard_t, bias, pg
):
"""
Entry-point function to handle the logic of row-wise sharding of weight
for Linear. (Detailed explanations of the logic can be found in the
@ -172,7 +240,8 @@ def _handle_row_wise_sharding(input, world_size, weight, rank, local_shard_t, bi
bias: bias term of linear op.
pg: process group.
Returns: final result of linear operation.
Returns:
A :class:`_PartialTensor` object which stores the partial local result.
"""
# alltoall to gather all the appropriate inputs.
input_t = input.transpose(0, -1).contiguous()
@ -220,15 +289,146 @@ def _handle_row_wise_sharding(input, world_size, weight, rank, local_shard_t, bi
gathered_input = gathered_input.transpose(0, -1)
# Perform local matmuls for all shards
shard_size = local_shard_t.size()[0]
results = []
shard_size = local_shard_t.size()[0]
for r in range(world_size):
inp = torch.narrow(gathered_input, -1, r * shard_size, shard_size)
results.append(inp.matmul(local_shard_t))
results.append(
inp.matmul(local_shard_t) + _BiasTensorPartial.apply(world_size, bias)
)
# Gather all the results appropriately.
local_result = torch.empty_like(results[rank])
local_result = reduce_scatter(local_result, results, group=pg)
# Return the partial local result.
return _PartialTensor(torch.cat(results), pg)
# Return the appropriate local result.
return local_result + bias
def _handle_row_wise_sharding_sharded_tensor(
input, world_size, weight, local_shard_t, bias, pg
):
"""
Entry-point function to handle the logic of row-wise sharding of weight
for Linear when the input is a sharded tensor. (Detailed explanations
of the logic can be found in the comment for sharded_linear.)
Args:
input: matrix to be multiplied with the sharded weight.
world_size: number of ranks.
weight: sharded weight tensor.
local_shard_t: transposed local shard of the row-wise sharded weight, used for the local matmul.
bias: bias term of linear op.
pg: process group.
Returns:
A :class:`_PartialTensor` object which stores the partial local result.
"""
results = []
local_shard = input.local_shards()[0].tensor
indices = [0] * world_size
rearrange_results = False
for idx, placement in enumerate(input._sharding_spec.placements):
indices[placement.rank()] = idx
if idx != placement.rank():
rearrange_results = True
for tensor in torch.tensor_split(local_shard, world_size):
results.append(
tensor.matmul(local_shard_t) + _BiasTensorPartial.apply(world_size, bias)
)
if rearrange_results:
results = [results[idx] for idx in indices]
# Return the partial local result.
return _PartialTensor(torch.cat(results), pg)
def _init_sharded_tensor_from_local_result(
sharded_tensor,
local_result,
tensor_shard_dim,
result_shard_dim,
world_size,
pg,
):
"""
Given a sharded tensor and a local_result from an op performed on top of it,
we want to create a new sharded tensor from the local_result so that the next
op can be performed on the basis of the new sharded tensor. This can be seen
as the last step of the first phase of the Megatron-LM style model (tensor)
parallelism.
Args:
sharded_tensor: Sharded tensor which the op was performed on.
local_result: A tensor which is from the op performed on the local_shard of
the sharded_tensor.
tensor_shard_dim: Dim which the tensor is sharded on.
result_shard_dim: Dim which the new sharded tensor will be sharded on.
world_size: number of ranks.
pg (ProcessGroup, optional): The process group to work on. If None,
the default process group will be used.
Return:
A :class:`ShardedTensor` object which filled with local intermediate results.
"""
sharded_weight_metadata = copy.deepcopy(sharded_tensor.local_shards()[0].metadata)
current_offsets = [0] * local_result.dim()
current_offsets[result_shard_dim] = sharded_weight_metadata.shard_offsets[
tensor_shard_dim
]
global_size = list(local_result.size())
global_size[result_shard_dim] = sharded_tensor.size(tensor_shard_dim)
local_shard_metadata = ShardMetadata(
shard_offsets=current_offsets,
shard_sizes=list(local_result.size()),
placement=sharded_weight_metadata.placement,
)
local_shards = [Shard(local_result, local_shard_metadata)]
new_st = ShardedTensor._init_from_local_shards(
local_shards, tuple(global_size), process_group=pg
)
# Manually set sharding_spec
new_st._sharding_spec = copy.deepcopy(sharded_tensor._sharding_spec)
new_st._sharding_spec.dim = result_shard_dim
return new_st
class _BiasTensorNarrow(Function):
"""
Since we now return the intermediate results for col-wise sharding, we
need to narrow the bias term in the forward pass; in the backward pass,
we need to gather all gradients of the narrowed bias across all ranks.
"""
@staticmethod
def forward(ctx, world_size, start_pos, chunk_size, weight, pg, bias):
ctx.weight = weight
ctx.pg = pg
ctx.world_size = world_size
return torch.narrow(bias, 0, start_pos, chunk_size)
@staticmethod
def backward(ctx, grad_output):
results = []
for idx in range(ctx.world_size):
results.append(grad_output.clone())
return (None, None, None, None, None) + (
_result_distribute_with_col_rearrange(
results, grad_output, ctx.world_size, ctx.weight, ctx.pg
),
)
class _BiasTensorPartial(Function):
"""
Since we now only return partial results for row-wise sharding, we need to
divide the bias term by the world size in the forward pass; in the backward
pass, we skip this division op.
"""
@staticmethod
def forward(ctx, world_size, bias):
ctx.world_size = world_size
return torch.div(bias, world_size)
@staticmethod
def backward(ctx, grad_output):
return (None, grad_output)

View File

@ -150,3 +150,30 @@ def get_chunked_dim_size(dim_size, split_size, idx):
An int indicating the dim size of the chunk.
"""
return max(min(dim_size, split_size * (idx + 1)) - split_size * idx, 0)
def get_chunk_sharding_params(sharding_dim_size, world_size, spec, rank):
"""
Generate the start pos and offset length for the current rank for
chunk sharding.
Args:
sharding_dim_size(int): The dimension length which we shard on.
world_size(int): number of ranks.
spec (:class:`torch.distributed._shard.sharding_spec.ChunkShardingSpec`):
sharding spec.
rank (int): rank of the current CUDA process.
Returns:
start_pos(int): start position of sharded tensor on the given rank.
chunk_size(int): chunk size of sharded tensor on the given rank.
"""
split_size = get_split_size(sharding_dim_size, world_size)
current_offsets = 0
start_pos = current_offsets
for idx, placement in enumerate(spec.placements):
chunk_size = get_chunked_dim_size(sharding_dim_size, split_size, idx)
if rank == placement.rank():
start_pos = current_offsets
break
current_offsets += chunk_size
return start_pos, chunk_size

View File

@ -1,3 +1,4 @@
import torch
from torch.distributed._shard.sharding_spec import (
ChunkShardingSpec,
)
@ -70,3 +71,17 @@ def generate_local_weight_sharding_params_for_test(
break
current_offsets += chunk_size
return start_pos, chunk_size
def clone_module_parameter(module, param_name):
"""
Clone a parameter from a given existing module.
Args:
module (:class:`torch.nn.Module`): Module whose parameter needs to be cloned.
param_name (str): Name of the parameter of ``module`` that needs to be cloned.
Returns: cloned tensor as :class:`torch.nn.Parameter`.
"""
tensor = getattr(module, param_name)
return torch.nn.Parameter(tensor.detach().clone())