Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-07 12:21:27 +01:00
[PT-D] Enable megatron-lm style MLP layers (Changes mainly on sharded linear op) (#69735)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/69735
We want to build a prototype of Megatron-LM so that we can apply PT-D sharded ops to models such as transformers and other Meta flagship models.
The basic idea of Megatron-LM is as follows:
1. Col-wise sharding of the linear weight. Perform the linear op for the first layer.
2. Perform an optional math op, such as ReLU or GeLU. We use GeLU in our example unit test. The input is from step 1.
3. Row-wise sharding of the linear weight. Perform the linear op for the second layer. The input is from step 2.
This saves the communication otherwise needed to concatenate the col-wise sharding results and to spread the input across ranks for the row-wise sharding, as the sketch below illustrates.
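A minimal sketch of the pattern, based on the unit test added in this diff: SimpleMegatronLM mirrors the toy model from the new test, while shard_megatron_mlp is an illustrative helper name (not an API added by this PR). It assumes a NCCL process group is already initialized with one GPU per rank.

import copy
import torch
from torch.distributed._shard import shard_parameter
from torch.distributed._shard.sharding_spec import ChunkShardingSpec
from torch.distributed._shard.sharded_tensor import _collect_local_shard, _reshard_output

class SimpleMegatronLM(torch.nn.Module):
    def __init__(self, linear_size):
        super().__init__()
        self.fc1 = torch.nn.Linear(*linear_size[0])
        self.gelu = torch.nn.GELU()
        self.fc2 = torch.nn.Linear(*linear_size[1])

    def forward(self, inp):
        return self.fc2(self.gelu(self.fc1(inp)))

def shard_megatron_mlp(model, world_size):
    # Col-wise shard fc1.weight (dim 0 of the [out_features, in_features] weight)
    # and row-wise shard fc2.weight (dim 1), matching steps 1-3 above.
    placements = [f"rank:{r}/cuda:{r}" for r in range(world_size)]
    colwise_spec = ChunkShardingSpec(dim=0, placements=placements)
    rowwise_spec = ChunkShardingSpec(dim=1, placements=placements)
    shard_parameter(model.fc1, "weight", colwise_spec)
    shard_parameter(model.fc2, "weight", rowwise_spec)
    # fc1 now produces a ShardedTensor and fc2 a partial result, so no
    # collective is needed between the two layers. Reshard/collect hooks
    # turn the final output back into a regular local tensor.
    reshard_spec = copy.deepcopy(rowwise_spec)
    reshard_spec.dim = 0
    return _collect_local_shard(_reshard_output(model, reshard_spec))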
The changes are as follows:
1. Return a ShardedTensor for the col-wise sharding in the sharded_linear op.
2. Return a PartialTensor for the row-wise sharding in the sharded_linear op.
3. Leverage the APIs already defined for `reshard` to merge/aggregate local results into a fully synced local result when needed.
4. Add a helper function to create a sharded tensor based on the local result.
5. Add a unit test that exercises the Megatron-LM idea mentioned above and compares against local ops, including the grad and optimizer (sketched after this list), so that we can ensure the correctness of the implementation.
6. Refactor the unit test of sharded linear to reflect the changes in the code.
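A rough sketch of the grad/optimizer check from item 5, following the new unit test. It assumes the two-layer module from the sketch above has already been sharded; the function name and variables are illustrative.

import torch
from torch.distributed._shard.sharded_optim import (
    ShardedOptimizer,
    named_params_with_sharded_tensor,
)

def train_step(sharded_model, inp):
    # Forward through the sharded MLP and reduce to a scalar loss.
    output = sharded_model(inp)
    output.sum().backward()
    # named_params_with_sharded_tensor also yields ShardedTensor parameters,
    # so the optimizer updates the local weight shards in place.
    sharded_optim = ShardedOptimizer(
        dict(named_params_with_sharded_tensor(sharded_model)),
        torch.optim.SGD,
        lr=0.1,
    )
    sharded_optim.step()
    # The local shard of each weight can then be compared against the
    # correspondingly narrowed slice of an unsharded reference model.
    return sharded_model.fc1.weight.local_tensor()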
ghstack-source-id: 148273049
Test Plan: Unit test + CI
Reviewed By: pritamdamania87
Differential Revision: D32978221
fbshipit-source-id: 565fc92e7807e19d53b0261f8ace3945bef69e3e
(cherry picked from commit 344abe7520)
This commit is contained in:
parent 19d0de8a57
commit 88547396eb
@@ -0,0 +1,72 @@
# Owner(s): ["oncall: distributed"]

import sys

import torch
from torch.distributed._shard import sharded_tensor
from torch.testing._internal.common_distributed import (
    requires_nccl,
    skip_if_lt_x_gpu,
)
from torch.testing._internal.common_utils import (
    TEST_WITH_DEV_DBG_ASAN,
    run_tests,
)
from torch.testing._internal.distributed._shard.sharded_tensor import (
    TEST_GPU_NUM,
    ShardedTensorTestBase,
    with_comms,
)
from torch.testing._internal.distributed._shard.sharded_tensor._test_ops_common import (
    generate_chunk_sharding_specs_for_test,
)

if TEST_WITH_DEV_DBG_ASAN:
    print(
        "Skip dev-asan as torch + multiprocessing spawn have known issues",
        file=sys.stderr,
    )
    sys.exit(0)


class TestShardedTensorElementWiseOps(ShardedTensorTestBase):
    def _run_sharded_elementwise_ops(self, spec, input_size, op):
        torch.manual_seed(self.rank)
        st = sharded_tensor.rand(spec, *input_size)
        new_st = op(st)
        local_shard = st.local_tensor()
        new_st_local_shard = new_st.local_tensor()
        self.assertEqual(
            op(local_shard),
            new_st_local_shard,
        )

    @with_comms(init_rpc=False)
    @skip_if_lt_x_gpu(TEST_GPU_NUM)
    @requires_nccl()
    def test_sharded_gelu(self):
        specs = generate_chunk_sharding_specs_for_test(
            0
        ) + generate_chunk_sharding_specs_for_test(1)
        for spec in specs:
            self._run_sharded_elementwise_ops(spec, [12, 17], torch.nn.functional.gelu)
            self._run_sharded_elementwise_ops(spec, [18, 21], torch.nn.functional.gelu)
            self._run_sharded_elementwise_ops(spec, [17, 23], torch.nn.functional.gelu)
            self._run_sharded_elementwise_ops(spec, [14, 15], torch.nn.functional.gelu)

    @with_comms(init_rpc=False)
    @skip_if_lt_x_gpu(TEST_GPU_NUM)
    @requires_nccl()
    def test_sharded_relu(self):
        specs = generate_chunk_sharding_specs_for_test(
            0
        ) + generate_chunk_sharding_specs_for_test(1)
        for spec in specs:
            self._run_sharded_elementwise_ops(spec, [12, 17], torch.nn.functional.relu)
            self._run_sharded_elementwise_ops(spec, [18, 21], torch.nn.functional.relu)
            self._run_sharded_elementwise_ops(spec, [17, 23], torch.nn.functional.relu)
            self._run_sharded_elementwise_ops(spec, [14, 15], torch.nn.functional.relu)


if __name__ == "__main__":
    run_tests()
@@ -21,6 +21,7 @@ from torch.testing._internal.distributed._shard.sharded_tensor import (
    with_comms,
)
from torch.testing._internal.distributed._shard.sharded_tensor._test_ops_common import (
    clone_module_parameter,
    generate_chunk_sharding_specs_for_test,
    generate_local_weight_sharding_params_for_test,
)

@@ -64,8 +65,8 @@ class TestShardedEmbedding(ShardedTensorTestBase):
        )

        # Copy the weights from local embedding
        sharded_embedding.weight = torch.nn.Parameter(
            local_embedding.weight.detach().clone()
        sharded_embedding.weight = clone_module_parameter(
            local_embedding, "weight"
        )

        # Shard the parameter.
@@ -21,6 +21,7 @@ from torch.testing._internal.distributed._shard.sharded_tensor import (
    with_comms,
)
from torch.testing._internal.distributed._shard.sharded_tensor._test_ops_common import (
    clone_module_parameter,
    generate_chunk_sharding_specs_for_test,
    generate_local_weight_sharding_params_for_test,
)

@@ -71,8 +72,8 @@ class TestShardedEmbeddingBag(ShardedTensorTestBase):
        )

        # Copy the weights from local embedding bag.
        sharded_embedding_bag.weight = torch.nn.Parameter(
            local_embedding_bag.weight.detach().clone()
        sharded_embedding_bag.weight = clone_module_parameter(
            local_embedding_bag, "weight"
        )

        # Shard the parameter.
@@ -1,21 +1,24 @@
# Owner(s): ["oncall: distributed"]

import copy
import sys

import torch
import torch.distributed as dist
from torch.distributed._shard import shard_parameter
from torch.distributed._shard.sharded_optim import (
    ShardedOptimizer,
    named_params_with_sharded_tensor,
)
from torch.distributed._shard.sharded_tensor import (
    empty,
    _collect_local_shard,
    _reshard_output,
)
from torch.distributed._shard.sharding_spec import (
    ChunkShardingSpec,
    EnumerableShardingSpec,
    ShardMetadata
)
from torch.distributed._shard.sharded_optim import (
    ShardedOptimizer,
    named_params_with_sharded_tensor,
    ShardMetadata,
)
from torch.testing._internal.common_distributed import (
    requires_nccl,

@@ -31,6 +34,7 @@ from torch.testing._internal.distributed._shard.sharded_tensor import (
    with_comms,
)
from torch.testing._internal.distributed._shard.sharded_tensor._test_ops_common import (
    clone_module_parameter,
    generate_chunk_sharding_specs_for_test,
    generate_local_weight_sharding_params_for_test,
)

@@ -44,16 +48,17 @@ if TEST_WITH_DEV_DBG_ASAN:


class TestShardedTensorOpsLinear(ShardedTensorTestBase):
    def _run_sharded_linear(self, spec, input_size, linear_size, sharded_dim):
    def _run_sharded_linear(
        self, spec, input_size, linear_size, sharded_dim
    ):
        # Use same seed.
        torch.manual_seed(0)
        local_linear = torch.nn.Linear(*linear_size).cuda(self.rank)

        sharded_linear = torch.nn.Linear(*linear_size)

        # Copy the weights and bias from local linear
        sharded_linear.weight = torch.nn.Parameter(local_linear.weight.detach().clone())
        sharded_linear.bias = torch.nn.Parameter(local_linear.bias.detach().clone())
        sharded_linear.weight = clone_module_parameter(local_linear, "weight")
        sharded_linear.bias = clone_module_parameter(local_linear, "bias")

        # Shard the parameter.
        shard_parameter(sharded_linear, "weight", spec)

@@ -61,6 +66,11 @@ class TestShardedTensorOpsLinear(ShardedTensorTestBase):
        # Run sharded computation
        torch.manual_seed(self.rank)  # inputs different on each rank
        inp = torch.rand(*input_size).cuda(self.rank)
        reshard_spec = copy.deepcopy(spec)
        reshard_spec.dim = 0
        sharded_linear = _collect_local_shard(
            _reshard_output(sharded_linear, reshard_spec)
        )
        sharded_output = sharded_linear(inp)

        # Run local computation

@@ -76,6 +86,11 @@ class TestShardedTensorOpsLinear(ShardedTensorTestBase):
        sharded_output = torch.nn.functional.linear(
            inp, sharded_linear.weight, sharded_linear.bias
        )
        sharded_output = sharded_output.reshard(reshard_spec).local_tensor()
        # When local tensor only has one dimension, we increase one more dimension
        # for reshard. We need to squeeze the # of dimensions manually.
        if inp.dim() == 1:
            sharded_output = sharded_output.squeeze(reshard_spec.dim)
        self.assertEqual(local_output, sharded_output)

        # Compute loss and run backward pass.

@@ -84,7 +99,7 @@ class TestShardedTensorOpsLinear(ShardedTensorTestBase):
        local_grad = local_linear.weight.grad

        # Verify that both weight and bias in the sharded linear has non-None grad.
        sharded_weight = sharded_linear.weight.local_shards()[0].tensor
        sharded_weight = sharded_linear.weight.local_tensor()
        self.assertNotEqual(sharded_linear.bias.grad, None)
        self.assertNotEqual(sharded_weight.grad, None)

@@ -94,9 +109,11 @@ class TestShardedTensorOpsLinear(ShardedTensorTestBase):
            local_linear.weight, sharded_dim, TEST_GPU_NUM, spec, self.rank
        )
        local_grad_narrowed = local_grad.narrow(sharded_dim, start_pos, chunk_size)
        local_bias_grad = local_linear.bias.grad
        dist.all_reduce(local_bias_grad)

        # Test backward gradient calculation.
        self.assertEqual(sharded_linear.bias.grad, local_linear.bias.grad)
        self.assertEqual(sharded_linear.bias.grad, local_bias_grad)
        self.assertEqual(sharded_weight.grad, local_grad_narrowed)

        # Test optimizer.

@@ -106,9 +123,13 @@ class TestShardedTensorOpsLinear(ShardedTensorTestBase):
        self.assertNotEqual(previous, local_linear.weight)
        previous_sharded_weight = sharded_weight.clone()
        previous_sharded_bias = sharded_linear.bias.clone()
        sharded_optim = ShardedOptimizer(dict(named_params_with_sharded_tensor(sharded_linear)), torch.optim.SGD, lr=0.1)
        sharded_optim = ShardedOptimizer(
            dict(named_params_with_sharded_tensor(sharded_linear)),
            torch.optim.SGD,
            lr=0.1,
        )
        sharded_optim.step()
        sharded_weight = sharded_linear.weight.local_shards()[0].tensor
        sharded_weight = sharded_linear.weight.local_tensor()
        local_weight_narrowed = local_linear.weight.narrow(
            sharded_dim, start_pos, chunk_size
        )

@@ -169,7 +190,7 @@ class TestShardedTensorOpsLinear(ShardedTensorTestBase):
        for spec in generate_chunk_sharding_specs_for_test(0):
            fc1 = torch.nn.Linear(10, 10).cuda(self.rank)
            shard_parameter(fc1, "bias", spec)
            with self.assertRaisesRegex(TypeError, 'input and bias need to be torch.Tensor'):
            with self.assertRaisesRegex(TypeError, 'bias needs to be torch.Tensor'):
                fc1(torch.rand(10, 10).cuda(self.rank))

            fc2 = torch.nn.Linear(10, 10).cuda(self.rank)
@@ -0,0 +1,224 @@
# Owner(s): ["oncall: distributed"]

import copy
import sys

import torch
import torch.distributed as dist
from torch.distributed._shard.sharded_optim import (
    ShardedOptimizer,
    named_params_with_sharded_tensor,
)
from torch.distributed._shard import (
    shard_parameter,
)
from torch.distributed._shard.sharded_tensor import (
    _collect_local_shard,
    _reshard_output,
)
from torch.testing._internal.common_distributed import (
    requires_nccl,
    skip_if_lt_x_gpu,
)
from torch.testing._internal.common_utils import (
    TEST_WITH_DEV_DBG_ASAN,
    run_tests,
)
from torch.testing._internal.distributed._shard.sharded_tensor import (
    TEST_GPU_NUM,
    ShardedTensorTestBase,
    with_comms,
)
from torch.testing._internal.distributed._shard.sharded_tensor._test_ops_common import (
    clone_module_parameter,
    generate_chunk_sharding_specs_for_test,
    generate_local_weight_sharding_params_for_test,
)

if TEST_WITH_DEV_DBG_ASAN:
    print(
        "Skip dev-asan as torch + multiprocessing spawn have known issues",
        file=sys.stderr,
    )
    sys.exit(0)


class TestShardedTensorMegatronLinear(ShardedTensorTestBase):
    class SimpleMegatronLM(torch.nn.Module):
        def __init__(self, linear_size, rank=None):
            super().__init__()
            self.fc1 = torch.nn.Linear(*linear_size[0])
            self.gelu = torch.nn.GELU()
            self.fc2 = torch.nn.Linear(*linear_size[1])
            if rank:
                self.fc1.cuda(rank)
                self.fc2.cuda(rank)

        def forward(self, inp):
            return self.fc2(self.gelu(self.fc1(inp)))

    def _run_megatron_linear(self, spec, input_size, linear_size):
        def _weight_override(module_dst, module_src):
            module_dst.fc1.weight = clone_module_parameter(module_src.fc1, "weight")
            module_dst.fc1.bias = clone_module_parameter(module_src.fc1, "bias")
            module_dst.fc2.weight = clone_module_parameter(module_src.fc2, "weight")
            module_dst.fc2.bias = clone_module_parameter(module_src.fc2, "bias")

        def _shard_parameter(module, spec):
            shard_parameter(module.fc1, "weight", spec[0])
            shard_parameter(module.fc2, "weight", spec[1])

        def _get_weight_grad(module):
            return (module.fc1.weight.grad, module.fc2.weight.grad)

        def _get_bias_grad(module):
            return (module.fc1.bias.grad, module.fc2.bias.grad)

        def _get_weights(module):
            return (module.fc1.weight, module.fc2.weight)

        def _get_bias(module):
            return (module.fc1.bias, module.fc2.bias)

        def _get_weight_local_shard(module):
            return (
                module.fc1.weight.local_tensor(),
                module.fc2.weight.local_tensor(),
            )

        # Use same seed.
        torch.manual_seed(0)
        local_megatron_lm = self.SimpleMegatronLM(linear_size, rank=self.rank).cuda(
            self.rank
        )
        sharded_megatron_lm = self.SimpleMegatronLM(linear_size)
        _weight_override(sharded_megatron_lm, local_megatron_lm)

        # Shard the parameter. First col-wise sharding and then row-wise
        _shard_parameter(sharded_megatron_lm, spec)

        # Run sharded computation
        torch.manual_seed(self.rank)  # inputs different on each rank
        inp = torch.rand(*input_size).cuda(self.rank)
        reshard_spec = copy.deepcopy(spec[1])
        reshard_spec.placements.sort(key=lambda placement: placement.rank())
        reshard_spec.dim = 0

        sharded_megatron_lm = _collect_local_shard(
            _reshard_output(sharded_megatron_lm, reshard_spec)
        )
        sharded_output = sharded_megatron_lm(inp)

        # Run local computation
        local_output = local_megatron_lm(inp)

        # Verify
        self.assertEqual(local_output, sharded_output)

        # Compute loss and run backward pass.
        local_output.sum().backward()
        sharded_output.sum().backward()
        (
            local_weight_grad_fc1,
            local_weight_grad_fc2,
        ) = _get_weight_grad(local_megatron_lm)
        local_bias_grad_fc1, local_bias_grad_fc2 = _get_bias_grad(local_megatron_lm)

        # Verify that weights in both layers and biases in the sharded linear has non-None grad.
        (
            sharded_weight_fc1,
            sharded_weight_fc2,
        ) = _get_weight_local_shard(sharded_megatron_lm)
        bias_grad_fc1, bias_grad_fc2 = _get_bias_grad(sharded_megatron_lm)
        self.assertNotEqual(sharded_weight_fc1.grad, None)
        self.assertNotEqual(sharded_weight_fc2.grad, None)
        self.assertNotEqual(bias_grad_fc1, None)
        self.assertNotEqual(bias_grad_fc2, None)

        # Shard the local linear's weight grad so that we can compare.
        dist.all_reduce(local_weight_grad_fc1)
        dist.all_reduce(local_weight_grad_fc2)
        dist.all_reduce(local_bias_grad_fc1)
        dist.all_reduce(local_bias_grad_fc2)
        local_weight_fc1, local_weight_fc2 = _get_weights(local_megatron_lm)
        (
            start_pos_fc1,
            chunk_size_fc1,
        ) = generate_local_weight_sharding_params_for_test(
            local_weight_fc1, 0, TEST_GPU_NUM, spec[0], self.rank
        )
        local_grad_narrowed_fc1 = local_weight_grad_fc1.narrow(
            0, start_pos_fc1, chunk_size_fc1
        )
        (
            start_pos_fc2,
            chunk_size_fc2,
        ) = generate_local_weight_sharding_params_for_test(
            local_weight_fc2, 1, TEST_GPU_NUM, spec[1], self.rank
        )
        local_grad_narrowed_fc2 = local_weight_grad_fc2.narrow(
            1, start_pos_fc2, chunk_size_fc2
        )

        # Test backward gradient calculation.
        self.assertEqual(sharded_weight_fc1.grad, local_grad_narrowed_fc1)
        self.assertEqual(sharded_weight_fc2.grad, local_grad_narrowed_fc2)
        self.assertEqual(bias_grad_fc1, local_bias_grad_fc1)
        self.assertEqual(bias_grad_fc2, local_bias_grad_fc2)

        # Test optimizer.
        bias_fc1, bias_fc2 = _get_bias(sharded_megatron_lm)
        local_bias_fc1, local_bias_fc2 = _get_bias(local_megatron_lm)
        self.assertEqual(bias_fc1, local_bias_fc1)
        self.assertEqual(bias_fc2, local_bias_fc2)
        self.assertEqual(bias_fc1.grad, local_bias_fc1.grad)
        self.assertEqual(bias_fc2.grad, local_bias_fc2.grad)
        previous_sharded_weight_fc1 = sharded_weight_fc1.clone()
        previous_sharded_weight_fc2 = sharded_weight_fc2.clone()
        previous_bias_fc1 = bias_fc1.clone()
        previous_bias_fc2 = bias_fc2.clone()
        optim = torch.optim.SGD(local_megatron_lm.parameters(), lr=0.1)
        optim.step()
        sharded_optim = ShardedOptimizer(
            dict(named_params_with_sharded_tensor(sharded_megatron_lm)),
            torch.optim.SGD,
            lr=0.1,
        )
        sharded_optim.step()
        local_weight_fc1_narrowed = local_weight_fc1.narrow(
            0, start_pos_fc1, chunk_size_fc1
        )
        local_weight_fc2_narrowed = local_weight_fc2.narrow(
            1, start_pos_fc2, chunk_size_fc2
        )

        # Test weight value after optimizer.
        self.assertEqual(sharded_weight_fc1.size(), local_weight_fc1_narrowed.size())
        self.assertEqual(sharded_weight_fc2.size(), local_weight_fc2_narrowed.size())
        self.assertNotEqual(previous_sharded_weight_fc1, sharded_weight_fc1)
        self.assertNotEqual(previous_sharded_weight_fc2, sharded_weight_fc2)
        self.assertEqual(sharded_weight_fc1, local_weight_fc1_narrowed)
        self.assertEqual(sharded_weight_fc2, local_weight_fc2_narrowed)

        # Test bias value after optimizer.
        local_bias_fc1, local_bias_fc2 = _get_bias(local_megatron_lm)
        self.assertNotEqual(previous_bias_fc1, bias_fc1)
        self.assertEqual(bias_fc1, local_bias_fc1)
        self.assertNotEqual(previous_bias_fc2, bias_fc2)
        self.assertEqual(bias_fc2, local_bias_fc2)

    @with_comms(init_rpc=False)
    @skip_if_lt_x_gpu(TEST_GPU_NUM)
    @requires_nccl()
    def test_megatron_two_layer_prototype(self):
        colwise_sharding_spec = generate_chunk_sharding_specs_for_test(0)
        rowwise_sharding_spec = generate_chunk_sharding_specs_for_test(1)
        for spec in zip(colwise_sharding_spec, rowwise_sharding_spec):
            self._run_megatron_linear(spec, [22, 17], [[17, 12], [12, 29]])
            self._run_megatron_linear(spec, [28, 21], [[21, 11], [11, 29]])
            self._run_megatron_linear(spec, [37, 23], [[23, 13], [13, 24]])
            self._run_megatron_linear(spec, [24, 15], [[15, 14], [14, 20]])


if __name__ == "__main__":
    run_tests()
@@ -12,6 +12,7 @@ from torch.distributed._shard.sharding_spec._internals import (
    check_tensor,
    get_split_size,
    get_chunked_dim_size,
    get_chunk_sharding_params,
)

from torch.testing._internal.common_utils import (

@@ -248,5 +249,32 @@ class TestShardingSpec(TestCase):
        self.assertEqual(1, get_chunked_dim_size(13, 4, 3))
        self.assertEqual(0, get_chunked_dim_size(5, 2, 3))

    def test_get_chunk_sharding_params(self):
        ranks = [
            "rank:0/cuda:0",
            "rank:1/cuda:1",
            "rank:2/cuda:2",
            "rank:3/cuda:3",
        ]
        spec = ChunkShardingSpec(
            dim=0,
            placements=ranks,
        )
        result = get_chunk_sharding_params(21, 4, spec, 1)
        self.assertEqual(6, result[0])
        self.assertEqual(6, result[1])
        result = get_chunk_sharding_params(21, 4, spec, 3)
        self.assertEqual(18, result[0])
        self.assertEqual(3, result[1])
        ranks[1], ranks[2] = ranks[2], ranks[1]
        ranks[0], ranks[3] = ranks[3], ranks[0]
        spec.placements = ranks
        result = get_chunk_sharding_params(21, 4, spec, 1)
        self.assertEqual(12, result[0])
        self.assertEqual(6, result[1])
        result = get_chunk_sharding_params(21, 4, spec, 3)
        self.assertEqual(0, result[0])
        self.assertEqual(6, result[1])

if __name__ == '__main__':
    run_tests()
@@ -203,9 +203,11 @@ WINDOWS_BLOCKLIST = [
    "distributed/elastic/agent/server/test/api_test",
    "distributed/elastic/multiprocessing/api_test",
    "distributed/_shard/sharding_spec/test_sharding_spec",
    "distributed/_shard/sharded_tensor/test_megatron_prototype",
    "distributed/_shard/sharded_tensor/test_sharded_tensor",
    "distributed/_shard/sharded_tensor/test_sharded_tensor_reshard",
    "distributed/_shard/sharded_tensor/test_partial_tensor",
    "distributed/_shard/sharded_tensor/ops/test_elementwise_ops",
    "distributed/_shard/sharded_tensor/ops/test_embedding",
    "distributed/_shard/sharded_tensor/ops/test_embedding_bag",
    "distributed/_shard/sharded_tensor/ops/test_binary_cmp",

@@ -219,9 +221,11 @@ ROCM_BLOCKLIST = [
    "distributed/rpc/test_faulty_agent",
    "distributed/rpc/test_tensorpipe_agent",
    "distributed/rpc/cuda/test_tensorpipe_agent",
    "distributed/_shard/sharded_tensor/test_megatron_prototype",
    "distributed/_shard/sharded_tensor/test_sharded_tensor",
    "distributed/_shard/sharded_tensor/test_sharded_tensor_reshard",
    "distributed/_shard/sharded_tensor/test_partial_tensor",
    "distributed/_shard/sharded_tensor/ops/test_elementwise_ops",
    "distributed/_shard/sharded_tensor/ops/test_embedding",
    "distributed/_shard/sharded_tensor/ops/test_embedding_bag",
    "distributed/_shard/sharded_tensor/ops/test_binary_cmp",

@@ -360,9 +364,11 @@ DISTRIBUTED_TESTS = [
    "distributed/elastic/utils/distributed_test",
    "distributed/elastic/multiprocessing/api_test",
    "distributed/_shard/sharding_spec/test_sharding_spec",
    "distributed/_shard/sharded_tensor/test_megatron_prototype",
    "distributed/_shard/sharded_tensor/test_sharded_tensor",
    "distributed/_shard/sharded_tensor/test_sharded_tensor_reshard",
    "distributed/_shard/sharded_tensor/test_partial_tensor",
    "distributed/_shard/sharded_tensor/ops/test_elementwise_ops",
    "distributed/_shard/sharded_tensor/ops/test_embedding",
    "distributed/_shard/sharded_tensor/ops/test_embedding_bag",
    "distributed/_shard/sharded_tensor/ops/test_binary_cmp",
@@ -439,6 +439,12 @@ def _collect_local_shard(module: torch.nn.Module) -> torch.nn.Module:
    """
    Hook a module with local shards collection in the forward pass.

    This API is typically used to convert a sharded representation back to data parallel
    representation. In particular, it returns the local tensor for this Shard. If the
    size along the sharding dimension for the local tensor is 1, this dimension is removed
    from the final result. For example a [4, 16] ShardedTensor across 4 ranks is typically
    a local Tensor of size [16] across each rank and not [1, 16] across each rank.

    Args:
        module (:class:`torch.nn.Module`): Module whose output needs to be resharded.

@@ -448,7 +454,12 @@ def _collect_local_shard(module: torch.nn.Module) -> torch.nn.Module:

    def hook_func(_module, _input, output):
        if isinstance(output, ShardedTensor):
            return output.local_tensor()
        return output
            local_tensor = output.local_tensor()
            # Squeeze the # of dimensions manually.
            if local_tensor.size(output._sharding_spec.dim) == 1:  # type: ignore[attr-defined]
                local_tensor = local_tensor.squeeze(
                    output._sharding_spec.dim  # type: ignore[attr-defined]
                )
            return local_tensor
    module.register_forward_hook(hook_func)
    return module
@@ -1,8 +1,7 @@
from .init import kaiming_uniform_, normal_, uniform_
from .linear import sharded_linear
import torch.distributed._shard.sharded_tensor._ops.elementwise_ops

from .binary_cmp import equal, allclose
from .embedding import sharded_embedding
from .embedding_bag import sharded_embedding_bag
from .binary_cmp import (
    equal,
    allclose
)
from .init import kaiming_uniform_, normal_, uniform_
from .linear import sharded_linear
@@ -16,7 +16,6 @@ from torch.distributed.nn.functional import (

def _handle_col_wise_sharding_base(
    op_func,
    sharding_dim_size,
    col_dim,
    input,
    world_size,

@@ -39,7 +38,6 @@ def _handle_col_wise_sharding_base(

    Args:
        op_func: operator which is applied to the input tensor.
        sharding_dim_size: the max size of the column each rank gets.
        col_dim: dim of result tensor after the operation.
        input: tensor to be applied op on.
        world_size: number of ranks.

@@ -90,7 +88,7 @@ def _handle_col_wise_sharding_base(

    # Distribute results to each rank with col rearrangement.
    output = _result_distribute_with_col_rearrange(
        results, input, sharding_dim_size, world_size, weight, pg
        results, input, world_size, weight, pg
    )

    # transpose the output and return result.

@@ -98,7 +96,7 @@ def _handle_col_wise_sharding_base(


def _result_distribute_with_col_rearrange(
    results, input, sharding_dim_size, world_size, weight, pg
    results, input, world_size, weight, pg
):
    """
    For col-wise sharding of weight, we need to distribute

@@ -111,7 +109,6 @@ def _result_distribute_with_col_rearrange(
        results: results from ops applied to inputs from all ranks.
            We need to distribute them back to their original ranks.
        input: tensor to be applied op to.
        sharding_dim_size: the max size of the column each rank gets.
        world_size: number of ranks.
        weight: shareded weight tensor.
        pg: process group.

@@ -119,6 +116,8 @@ def _result_distribute_with_col_rearrange(
    Return: column rearranged result.
    """
    # Process results and outputs for all2all.
    sharding_dim = weight._sharding_spec.dim
    sharding_dim_size = weight.size(sharding_dim)
    dims = list(results[0].size())
    dims[0] = sharding_dim_size
    output = torch.empty(*dims, device=input.device)
@@ -0,0 +1,38 @@
import copy

import torch
from torch.distributed._shard.sharded_tensor import (
    sharded_op_impl,
    Shard,
    ShardedTensor,
)


def register_elementwise_op(op):
    @sharded_op_impl(op)
    def elementwise_op(types, args=(), kwargs=None, pg=None):
        """
        Handles ``__torch_function__`` dispatch for the elementwise op such
        as ``torch.nn.functional.gelu`` or ``torch.nn.functional.relu``.
        This method computes on either a normal tensor or a sharded tensor.
        """
        input = args[0]
        # Validate types
        if not isinstance(input, ShardedTensor):
            raise TypeError("input needs to be a ShardedTensor")
        local_shards_new = []
        for local_shard in input.local_shards():
            local_shards_new.append(Shard(op(local_shard.tensor), local_shard.metadata))
        # TODO: After a new API for sharded tensor creation, we need to replace this.
        # https://github.com/pytorch/pytorch/issues/72092
        new_st = ShardedTensor._init_from_local_shards(
            local_shards_new, input.size(), process_group=pg
        )

        # Manually set sharding_spec
        new_st._sharding_spec = copy.deepcopy(input._sharding_spec)
        return new_st


register_elementwise_op(torch.nn.functional.gelu)
register_elementwise_op(torch.nn.functional.relu)
@@ -104,7 +104,7 @@ def sharded_embedding(types, args, kwargs, pg):
    norm_type = kwargs.get("norm_type")
    padding_idx = kwargs.get("padding_idx")

    local_shard = weight.local_shards()[0].tensor.contiguous()
    local_shard = weight.local_tensor().contiguous()
    sharding_dim = weight._sharding_spec.dim
    world_size = dist.get_world_size(pg)
    rank = dist.get_rank(pg)

@@ -224,7 +224,6 @@ def _handle_col_wise_sharding(

    output = _handle_col_wise_sharding_base(
        torch.nn.functional.embedding,
        weight.size(1),
        len(input.size()),
        input,
        world_size,
@@ -123,7 +123,7 @@ def sharded_embedding_bag(types, args, kwargs, pg):
    include_last_offset = kwargs.get("include_last_offset")
    padding_idx = kwargs.get("padding_idx")

    local_shard = weight.local_shards()[0].tensor.contiguous()
    local_shard = weight.local_tensor().contiguous()
    sharding_dim = weight._sharding_spec.dim
    world_size = dist.get_world_size(pg)
    rank = dist.get_rank(pg)

@@ -319,7 +319,6 @@ def _handle_col_wise_sharding(

    output = _handle_col_wise_sharding_base(
        torch.nn.functional.embedding_bag,
        weight.size(1),
        1,
        input,
        world_size,
@@ -1,20 +1,29 @@
import copy
from typing import List, cast

import torch
import torch.distributed as dist
from torch.autograd import Function
from torch.distributed.nn.functional import (
    all_gather,
    all_to_all_single,
    reduce_scatter,
)
from torch.distributed._shard.sharded_tensor import sharded_op_impl, ShardedTensor
from torch.distributed._shard.sharded_tensor import (
    sharded_op_impl,
    _PartialTensor,
    Shard,
    ShardedTensor,
    ShardMetadata,
)
from torch.distributed._shard.sharding_spec import ChunkShardingSpec
from torch.distributed._shard.sharding_spec._internals import (
    get_split_size,
    get_chunked_dim_size,
    get_chunk_sharding_params,
)

from ._common import (
    _handle_col_wise_sharding_base,
    _result_distribute_with_col_rearrange,
)

@@ -27,6 +36,9 @@ def sharded_linear(types, args, kwargs, pg):
    1. Supports only sharding of ``weight``.
    2. Supports only ``ChunkShardingSpec``.
    3. Supports only a single local shard per rank.
    4. Tailored for Megatron-LM style model(tensor) parallelism. Further API
       calls are needed if a fully synced local tensor is needed.
       Megatron-LM paper link: https://arxiv.org/abs/1909.08053

    Based on the dimension that the weight is sharded on, there are two
    algorithms:
@@ -50,11 +62,13 @@ def sharded_linear(types, args, kwargs, pg):
       size that we need for the global result which would be (13 x 16)
       multiplied by (16 x 17). But the final result needs to be aggregated
       across the rest of the ranks.
    3. Now the local matmul results are aggregated and shared to all the
       corresponding ranks using a reduce_scatter operation ensuring each rank
    3. Here we just return the partial result here. One can call API
       aggregate_partial_tensor_list to get the aggregated final result.
       The API uses a reduce_scatter operation ensuring each rank
       aggregates its own result. This is essentially a sum operation across
       all the (13 x 17) local computations we did for each rank.
    4. Finally, we add the bias term locally to the final computation.
    4. For partial result, we only add 1 / n of the bias term to the partial
       result. n is # of all GPUs.

    COLWISE SHARDING
    ================
@@ -72,53 +86,39 @@ def sharded_linear(types, args, kwargs, pg):
    2. Next we perform local matmuls by multiplying each input (13 x 17)
       with the local shard (17 x 4) (transposed). This results in 4 (13 x 4)
       matrices on each rank.
    3. Next, we concat these 4 matrices and perform an all2all to share the
       appropriate (13 x 4) matrices to each rank.
    4. Now, each rank receives a (13 x 16) matrix which is basically the size
       of the result we need.
    3. Next, we stack them into a (4 x 13 x 4) tensor and build a sharded
       tensor across 4 ranks.
    4. To merge them into a fully-sync local tensor, one can call API
       merge_sharded_local_results.
       This API concat these 4 matrices and perform an all2all to share the
       appropriate (13 x 4) matrices to each rank. Specifically, each rank
       receives a (13 x 16) matrix which is basically the size of the result.
    5. If placements are not in order any appropriate rearrangement of rows
       are done for the (13 x 16) matrix and finally the bias term is added.
    """
    # Validate input params
    _validate_linear_op_param(args, kwargs)
    input = args[0]
    weight = args[1]
    bias = args[2]

    # Validate types
    if not isinstance(input, torch.Tensor) or not isinstance(bias, torch.Tensor):
        raise TypeError("input and bias need to be torch.Tensor")
    if not isinstance(weight, ShardedTensor):
        raise TypeError("weight needs to be ShardedTensor")
    if len(input.size()) < 1:
        raise ValueError("Input needs to have at least 1 dim")
    weight_size = cast(torch.Size, weight.size())
    if len(weight_size) != 2:
        raise ValueError("Weight needs to have exactly 2 dims")
    if len(bias.size()) != 1:
        raise ValueError("Bias needs to have exactly 1 dim")

    if input.size()[-1] != weight_size[1]:
        raise ValueError(
            f"Input dim: {input.size()[-1]} does not match "
            f"appropriate weight dim: {weight_size[1]}"
        )
    if not isinstance(weight._sharding_spec, ChunkShardingSpec):
        raise ValueError("Only ChunkShardingSpec supported for ShardedTensor ops!")
    if len(weight.local_shards()) != 1:
        raise ValueError("Only one local shard supported!")

    local_shard = weight.local_tensor()
    local_shard_t = local_shard.t().contiguous()
    sharding_dim = weight._sharding_spec.dim
    world_size = dist.get_world_size(pg)
    rank = dist.get_rank(pg)

    if sharding_dim == 1:
        return _handle_row_wise_sharding(
    if sharding_dim == 1 and isinstance(input, torch.Tensor):
        return _handle_row_wise_sharding_tensor(
            input, world_size, weight, rank, local_shard_t, bias, pg
        )
    elif sharding_dim == 1 and isinstance(input, ShardedTensor):
        return _handle_row_wise_sharding_sharded_tensor(
            input, world_size, weight, local_shard_t, bias, pg
        )
    elif sharding_dim == 0:
        return _handle_col_wise_sharding(
            input, world_size, weight, local_shard_t, bias, pg
            input, world_size, weight, rank, local_shard_t, bias, pg
        )
    else:
        raise RuntimeError(
@@ -126,38 +126,106 @@ def sharded_linear(types, args, kwargs, pg):
        )


def _handle_col_wise_sharding(input, world_size, weight, local_shard_t, bias, pg):
def _validate_linear_op_param(args, kwargs):
    """
    Validate input params of sharded embedding op.

    Args:
        input: input of the linear layer.
        weight: shareded weight tensor.
        kwargs: same as normal Linear.

    Return: None.
    """
    input = args[0]
    weight = args[1]
    bias = args[2]

    # Validate types
    if not isinstance(input, torch.Tensor) and not isinstance(input, ShardedTensor):
        raise TypeError("input needs to be either torch.Tensor or ShardedTensor")
    if not isinstance(bias, torch.Tensor):
        raise TypeError("bias needs to be torch.Tensor")
    if not isinstance(weight, ShardedTensor):
        raise TypeError("weight needs to be ShardedTensor")
    if len(input.size()) < 1:  # type: ignore[arg-type]
        raise ValueError("Input needs to have at least 1 dim")
    weight_size = cast(torch.Size, weight.size())
    if len(weight_size) != 2:
        raise ValueError("Weight needs to have exactly 2 dims")
    if len(bias.size()) != 1:
        raise ValueError("Bias needs to have exactly 1 dim")
    if input.size()[-1] != weight_size[1]:  # type: ignore[index]
        raise ValueError(
            f"Input dim: {input.size()[-1]} does not match "  # type: ignore[index]
            f"appropriate weight dim: {weight_size[1]}"
        )
    if not isinstance(weight._sharding_spec, ChunkShardingSpec):
        raise ValueError("Only ChunkShardingSpec supported for ShardedTensor ops!")
    if len(weight.local_shards()) != 1:
        raise ValueError("Only one local shard supported!")


def _handle_col_wise_sharding(input, world_size, weight, rank, local_shard_t, bias, pg):
    """
    Entry-point function to handle the logic of col-wise sharding of weight
    for Linear. (Detailed explanations of the logic can be found in the
    comment for sharded_linear.)

    When the local tensor only has one dimension, we increase one more dimension
    for reshard. We need to do squeeze manually to reduce the dimension later-on.

    For example, if we have:
        input: size[15]
        weight: size[15, 16]
        world_size: 4

    In each rank, we will have 4 * [4] tensors. We then stack them into a [4, 4]
    tensor and generate a sharded tenor sharded by dim 1.

    For the rest situations, we just simply concatenate local tensors. No more actions
    are needed afterward.

    Args:
        input: matrix to be multiplied with the sharded weight.
        world_size: number of ranks.
        weight: shareded weight tensor.
        rank: # of cuda process.
        local_shard_t: row-wise shared local weight used for lookup.
        bias: bias term of linear op.
        pg: process group.

    Returns: final result of linear operation.
    Returns:
        A :class:`ShardedTensor` object which filled with local intermediate results.
    """
    return (
        _handle_col_wise_sharding_base(
            torch.matmul,
            weight.size(0),
            len(input.size()) - 1,
            input,
            world_size,
            weight,
            local_shard_t,
            pg,
        )
        + bias
    # allgather the inputs first.
    gathered_inputs = all_gather(input, group=pg)
    (start_pos, chunk_size) = get_chunk_sharding_params(
        bias.size(0), world_size, weight._sharding_spec, rank
    )
    local_bias = _BiasTensorNarrow.apply(
        world_size, start_pos, chunk_size, weight, pg, bias
    )
    results = [None] * world_size
    indices = {}
    for idx, placement in enumerate(weight._sharding_spec.placements):
        indices[placement.rank()] = idx
    for i, inp in enumerate(gathered_inputs):
        results[indices[i]] = inp.matmul(local_shard_t) + local_bias
    # When the local result only has one dimension, we need to make sure
    # it does not shard by dim 0. So reshard can work properly.
    if results[0].dim() == 1:  # type: ignore[attr-defined]
        result = torch.stack(results)  # type: ignore[arg-type]
    else:
        result = torch.cat(results)  # type: ignore[arg-type]
    return _init_sharded_tensor_from_local_result(
        weight, result, 0, -1, world_size, pg  # type: ignore[arg-type]
    )


def _handle_row_wise_sharding(input, world_size, weight, rank, local_shard_t, bias, pg):
def _handle_row_wise_sharding_tensor(
    input, world_size, weight, rank, local_shard_t, bias, pg
):
    """
    Entry-point function to handle the logic of row-wise sharding of weight
    for Linear. (Detailed explanations of the logic can be found in the
@@ -172,7 +240,8 @@ def _handle_row_wise_sharding(input, world_size, weight, rank, local_shard_t, bias, pg):
        bias: bias term of linear op.
        pg: process group.

    Returns: final result of linear operation.
    Returns:
        A :class:`_PartialTensor` object which stores the partial local result.
    """
    # alltoall to gather all the appropriate inputs.
    input_t = input.transpose(0, -1).contiguous()
@@ -220,15 +289,146 @@ def _handle_row_wise_sharding(input, world_size, weight, rank, local_shard_t, bias, pg):
    gathered_input = gathered_input.transpose(0, -1)

    # Perform local matmuls for all shards
    shard_size = local_shard_t.size()[0]
    results = []
    shard_size = local_shard_t.size()[0]
    for r in range(world_size):
        inp = torch.narrow(gathered_input, -1, r * shard_size, shard_size)
        results.append(inp.matmul(local_shard_t))
        results.append(
            inp.matmul(local_shard_t) + _BiasTensorPartial.apply(world_size, bias)
        )

    # Gather all the results appropriately.
    local_result = torch.empty_like(results[rank])
    local_result = reduce_scatter(local_result, results, group=pg)
    # Return the partial local result.
    return _PartialTensor(torch.cat(results), pg)

    # Return the appropriate local result.
    return local_result + bias

def _handle_row_wise_sharding_sharded_tensor(
    input, world_size, weight, local_shard_t, bias, pg
):
    """
    Entry-point function to handle the logic of row-wise sharding of weight
    for Linear when the input is a sharded tensor. (Detailed explanations
    of the logic can be found in the comment for sharded_linear.)

    Args:
        input: matrix to be multiplied with the sharded weight.
        world_size: number of ranks.
        weight: shareded weight tensor.
        local_shard_t: row-wise shared local weight used for lookup.
        bias: bias term of linear op.
        pg: process group.

    Returns:
        A :class:`_PartialTensor` object which stores the partial local result.
    """
    results = []
    local_shard = input.local_shards()[0].tensor
    indices = [0] * world_size
    reaggrance_partial = False
    for idx, placement in enumerate(input._sharding_spec.placements):
        indices[placement.rank()] = idx
        if idx != placement.rank():
            reaggrance_partial = True

    for tensor in torch.tensor_split(local_shard, world_size):
        results.append(
            tensor.matmul(local_shard_t) + _BiasTensorPartial.apply(world_size, bias)
        )
    if reaggrance_partial:
        results = [results[idx] for idx in indices]

    # Return the partial local result.
    return _PartialTensor(torch.cat(results), pg)


def _init_sharded_tensor_from_local_result(
    sharded_tensor,
    local_result,
    tensor_shard_dim,
    result_shard_dim,
    world_size,
    pg,
):
    """
    Given a sharded tensor and local_result from an op on top of it. We want
    to create a new sharded tensor from the local_result so that the the next
    op can be performed on the basis of the new sharded tensor. This can seen
    as the last step of the first phase of the Megatron-LM style model(tensor)
    parallelism.

    Args:
        sharded_tensor: Sharded tensor which the op was performed on.
        local_result: A tensor which is from the op performed on the local_shard of
            the sharded_tensor.
        tensor_shard_dim: Dim which the tensor is sharded on.
        result_shard_dim: Dim which the new sharded tensor will be sharded on.
        world_size: number of ranks.
        pg (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used.

    Return:
        A :class:`ShardedTensor` object which filled with local intermediate results.
    """
    sharded_weight_metadata = copy.deepcopy(sharded_tensor.local_shards()[0].metadata)
    current_offsets = [0] * local_result.dim()
    current_offsets[result_shard_dim] = sharded_weight_metadata.shard_offsets[
        tensor_shard_dim
    ]
    global_size = list(local_result.size())
    global_size[result_shard_dim] = sharded_tensor.size(tensor_shard_dim)
    local_shard_metadata = ShardMetadata(
        shard_offsets=current_offsets,
        shard_sizes=list(local_result.size()),
        placement=sharded_weight_metadata.placement,
    )
    local_shards = [Shard(local_result, local_shard_metadata)]
    new_st = ShardedTensor._init_from_local_shards(
        local_shards, tuple(global_size), process_group=pg
    )

    # Manually set sharding_spec
    new_st._sharding_spec = copy.deepcopy(sharded_tensor._sharding_spec)
    new_st._sharding_spec.dim = result_shard_dim
    return new_st


class _BiasTensorNarrow(Function):
    """
    Since we now return the intermediate results in a col-wise sharding. We
    need to narrow the bias term in the forward while doing backward, we need
    to gather all gradients of narrowed bias across all ranks.
    """

    @staticmethod
    def forward(ctx, world_size, start_pos, chunk_size, weight, pg, bias):
        ctx.weight = weight
        ctx.pg = pg
        ctx.world_size = world_size
        return torch.narrow(bias, 0, start_pos, chunk_size)

    @staticmethod
    def backward(ctx, grad_output):
        results = []
        for idx in range(ctx.world_size):
            results.append(grad_output.clone())
        return (None, None, None, None, None) + (
            _result_distribute_with_col_rearrange(
                results, grad_output, ctx.world_size, ctx.weight, ctx.pg
            ),
        )


class _BiasTensorPartial(Function):
    """
    Since we now only return partial results in a row-wise sharding. We need to
    divide the bias term by the world size in the forward while doing backward,
    we need to skip this division op.
    """

    @staticmethod
    def forward(ctx, world_size, bias):
        ctx.world_size = world_size
        return torch.div(bias, world_size)

    @staticmethod
    def backward(ctx, grad_output):
        return (None, grad_output)
@@ -150,3 +150,30 @@ def get_chunked_dim_size(dim_size, split_size, idx):
        An int indicating the dim size of the chunk.
    """
    return max(min(dim_size, split_size * (idx + 1)) - split_size * idx, 0)

def get_chunk_sharding_params(sharding_dim_size, world_size, spec, rank):
    """
    Generate the start pos and offset length for the current rank for
    chunk sharding.

    Args:
        sharding_dim_size(int): The dimension length which we shard on.
        world_size(int): number of ranks.
        spec (:class:`torch.distributed._shard.sharding_spec.ChunkShardingSpec`):
            sharding spec.
        rank(int): # of cuda process.

    Returns:
        start_pos(int): start position of sharded tensor on the given rank.
        chunk_size(int): chunk size of sharded tensor on the given rank.
    """
    split_size = get_split_size(sharding_dim_size, world_size)
    current_offsets = 0
    start_pos = current_offsets
    for idx, placement in enumerate(spec.placements):
        chunk_size = get_chunked_dim_size(sharding_dim_size, split_size, idx)
        if rank == placement.rank():
            start_pos = current_offsets
            break
        current_offsets += chunk_size
    return start_pos, chunk_size
@@ -1,3 +1,4 @@
import torch
from torch.distributed._shard.sharding_spec import (
    ChunkShardingSpec,
)

@@ -70,3 +71,17 @@ def generate_local_weight_sharding_params_for_test(
            break
        current_offsets += chunk_size
    return start_pos, chunk_size


def clone_module_parameter(module, param_name):
    """
    Clone a parameter from a given existing module.

    Args:
        module (:class:`torch.nn.Module`): Module whose parameter needs to be cloned.
        param_name (str): Name of the parameter of ``module`` that needs to be cloned.

    Returns: cloned tensor as :class:`torch.nn.Parameter`.
    """
    tensor = getattr(module, param_name)
    return torch.nn.Parameter(tensor.detach().clone())