Additional unit test for sharded linear. (#70476)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/70476

1) Support a single dimension for inputs
2) Test several error cases

Partially addresses https://github.com/pytorch/pytorch/issues/65638

ghstack-source-id: 146307607

Test Plan: waitforbuildbot

Reviewed By: fduwjj

Differential Revision: D33344357

fbshipit-source-id: 4de7a7177452951dbcce76f27441703447609e6f
This commit is contained in:
parent 520226c1bf
commit 96dfded569
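For context, a minimal sketch (not part of this commit) of what change 1) enables: a Linear whose weight was sharded with shard_parameter and a ChunkShardingSpec can now be called with a plain 1-D input. The 4-rank, one-GPU-per-rank setup, the seeding, and the output comparison are illustrative assumptions modeled on the test fixtures below.

# Sketch only, not part of this commit. Assumes torch.distributed is already
# initialized with 4 ranks (NCCL), one GPU per rank, and that every rank runs
# this same code; seeding keeps weights and inputs identical across ranks.
import torch
import torch.distributed as dist
from torch.distributed._sharded_tensor import shard_parameter
from torch.distributed._sharding_spec import ChunkShardingSpec

torch.manual_seed(0)
rank = dist.get_rank()
spec = ChunkShardingSpec(
    dim=0,
    placements=[
        "rank:0/cuda:0",
        "rank:1/cuda:1",
        "rank:2/cuda:2",
        "rank:3/cuda:3",
    ],
)

fc = torch.nn.Linear(17, 12).cuda(rank)
local_fc = torch.nn.Linear(17, 12).cuda(rank)
local_fc.load_state_dict(fc.state_dict())   # dense copy kept for comparison
shard_parameter(fc, "weight", spec)          # fc.weight becomes a ShardedTensor

x = torch.rand(17).cuda(rank)                # a single input dim, rejected before this change
print(torch.allclose(fc(x), local_fc(x), atol=1e-6))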
@@ -5,8 +5,14 @@ import sys
import torch
import torch.distributed as dist
from torch.distributed._sharded_tensor import (
    empty,
    shard_parameter,
)
from torch.distributed._sharding_spec import (
    ChunkShardingSpec,
    EnumerableShardingSpec,
    ShardMetadata
)
from torch.distributed._sharded_optim import (
    ShardedOptimizer,
    named_params_with_sharded_tensor,
@@ -128,6 +134,12 @@ class TestShardedTensorOpsLinear(ShardedTensorTestBase):
        self._run_sharded_linear(spec, [27, 7, 23], [23, 13], 0)
        self._run_sharded_linear(spec, [100, 12, 4, 15], [15, 14], 0)

        # Test single input dim
        self._run_sharded_linear(spec, [17], [17, 12], 0)
        self._run_sharded_linear(spec, [21], [21, 11], 0)
        self._run_sharded_linear(spec, [23], [23, 13], 0)
        self._run_sharded_linear(spec, [15], [15, 14], 0)

    @with_comms(init_rpc=False)
    @skip_if_lt_x_gpu(TEST_GPU_NUM)
    @requires_nccl()
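The _run_sharded_linear helper is defined earlier in the test file and is not part of this hunk; judging from the call sites, its arguments are (spec, input_shape, linear_in_out_features, sharded_dim), and that reading is an assumption. The expected dense output shapes for the new single-dim cases can be checked with plain PyTorch:

import torch
import torch.nn.functional as F

# (input_shape, (in_features, out_features)) pairs taken from the new test calls above.
for input_shape, (in_f, out_f) in [([17], (17, 12)), ([21], (21, 11)),
                                   ([23], (23, 13)), ([15], (15, 14))]:
    x = torch.rand(*input_shape)
    w = torch.rand(out_f, in_f)      # nn.Linear stores weight as (out_features, in_features)
    b = torch.rand(out_f)
    assert F.linear(x, w, b).shape == (out_f,)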
@@ -145,6 +157,91 @@ class TestShardedTensorOpsLinear(ShardedTensorTestBase):
        self._run_sharded_linear(spec, [10, 5, 19], [19, 11], 1)
        self._run_sharded_linear(spec, [12, 15, 10, 21], [21, 11], 1)

        # Test single input dim
        self._run_sharded_linear(spec, [16], [16, 11], 1)
        self._run_sharded_linear(spec, [19], [19, 11], 1)
        self._run_sharded_linear(spec, [21], [21, 11], 1)

    @with_comms(init_rpc=False)
    @skip_if_lt_x_gpu(TEST_GPU_NUM)
    @requires_nccl()
    def test_sharded_linear_errors(self):
        for spec in generate_chunk_sharding_specs_for_test(0):
            fc1 = torch.nn.Linear(10, 10).cuda(self.rank)
            shard_parameter(fc1, "bias", spec)
            with self.assertRaisesRegex(TypeError, 'input and bias need to be torch.Tensor'):
                fc1(torch.rand(10, 10).cuda(self.rank))

            fc2 = torch.nn.Linear(10, 10).cuda(self.rank)
            shard_parameter(fc2, "weight", spec)
            with self.assertRaisesRegex(ValueError, 'Input needs to have at least 1 dim'):
                fc2(torch.tensor(1).cuda(self.rank))

            fc3 = torch.nn.Linear(10, 10).cuda(self.rank)
            fc3.weight = torch.nn.Parameter(torch.rand(10, 10, 10).cuda(self.rank))
            shard_parameter(fc3, "weight", spec)
            with self.assertRaisesRegex(ValueError, 'Weight needs to have exactly 2 dims'):
                fc3(torch.rand(10, 10).cuda(self.rank))

            fc4 = torch.nn.Linear(10, 10).cuda(self.rank)
            fc4.bias = torch.nn.Parameter(torch.rand(10, 10).cuda(self.rank))
            shard_parameter(fc4, "weight", spec)
            with self.assertRaisesRegex(ValueError, 'Bias needs to have exactly 1 dim'):
                fc4(torch.rand(10, 10).cuda(self.rank))

            fc5 = torch.nn.Linear(7, 10).cuda(self.rank)
            shard_parameter(fc5, "weight", spec)
            with self.assertRaisesRegex(ValueError, 'Input dim: 13 does not match appropriate weight dim: 7'):
                fc5(torch.rand(20, 10, 13).cuda(self.rank))

            fc6 = torch.nn.Linear(10, 10).cuda(self.rank)
            del fc6.weight
            enumerable_spec = EnumerableShardingSpec([
                ShardMetadata(
                    shard_offsets=[0, 0],
                    shard_sizes=[5, 5],
                    placement="rank:0/cuda:0",
                ),
                ShardMetadata(
                    shard_offsets=[0, 5],
                    shard_sizes=[5, 5],
                    placement="rank:1/cuda:1",
                ),
                ShardMetadata(
                    shard_offsets=[5, 0],
                    shard_sizes=[5, 5],
                    placement="rank:2/cuda:2",
                ),
                ShardMetadata(
                    shard_offsets=[5, 5],
                    shard_sizes=[5, 5],
                    placement="rank:3/cuda:3",
                )
            ])

            fc6.weight = empty(enumerable_spec, 10, 10)
            with self.assertRaisesRegex(ValueError, 'Only ChunkShardingSpec supported for ShardedTensor ops!'):
                fc6(torch.rand(10, 10).cuda(self.rank))

            fc7 = torch.nn.Linear(10, 80).cuda(self.rank)
            multiple_local_shard_spec = ChunkShardingSpec(
                dim=0,
                placements=[
                    "rank:0/cuda:0",
                    "rank:0/cuda:0",
                    "rank:1/cuda:1",
                    "rank:1/cuda:1",
                    "rank:2/cuda:2",
                    "rank:2/cuda:2",
                    "rank:3/cuda:3",
                    "rank:3/cuda:3",
                ],
            )
            del fc7.weight
            fc7.weight = empty(multiple_local_shard_spec, 80, 10)
            with self.assertRaisesRegex(ValueError, 'Only one local shard supported!'):
                fc7(torch.rand(10, 10).cuda(self.rank))


if __name__ == "__main__":
    run_tests()
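One case above is easy to misread: fc7 shards an (80, 10) weight with eight placements across four ranks, so every rank ends up holding two 10-row chunks rather than one, which is what triggers 'Only one local shard supported!'. A rough way to see the per-rank assignment (the chunks-follow-placements-in-order behavior of ChunkShardingSpec is assumed here):

from collections import Counter

# fc7's placement list from the test above: two entries per rank.
placements = [
    "rank:0/cuda:0", "rank:0/cuda:0",
    "rank:1/cuda:1", "rank:1/cuda:1",
    "rank:2/cuda:2", "rank:2/cuda:2",
    "rank:3/cuda:3", "rank:3/cuda:3",
]
print(Counter(p.split("/")[0] for p in placements))
# Counter({'rank:0': 2, 'rank:1': 2, 'rank:2': 2, 'rank:3': 2})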
@@ -90,8 +90,8 @@ def sharded_linear(types, args, kwargs, pg):
         raise TypeError("input and bias need to be torch.Tensor")
     if not isinstance(weight, ShardedTensor):
         raise TypeError("weight needs to be ShardedTensor")
-    if len(input.size()) < 2:
-        raise ValueError("Input needs to have at least 2 dims")
+    if len(input.size()) < 1:
+        raise ValueError("Input needs to have at least 1 dim")
     weight_size = cast(torch.Size, weight.size())
     if len(weight_size) != 2:
         raise ValueError("Weight needs to have exactly 2 dims")
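The relaxed check still rejects the 0-dim input exercised by fc2 above, while letting 1-D vectors through; a quick illustration with plain tensors (nothing sharded):

import torch

scalar = torch.tensor(1)      # 0-dim tensor, as in the fc2 error case
vector = torch.rand(17)       # 1-D tensor, now accepted by the check above
print(len(scalar.size()), len(vector.size()))   # 0 1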
@@ -100,7 +100,7 @@ def sharded_linear(types, args, kwargs, pg):
     if input.size()[-1] != weight_size[1]:
         raise ValueError(
-            f"Input dim: {input.size()[1]} does not match "
+            f"Input dim: {input.size()[-1]} does not match "
             f"appropriate weight dim: {weight_size[1]}"
         )
     if not isinstance(weight._sharding_spec, ChunkShardingSpec):
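Indexing the error message with input.size()[-1] matters for inputs with more than two dims. Using the fc5 shapes from the test above (input (20, 10, 13), Linear(7, 10)), size()[1] is 10 and would mislead, while size()[-1] is 13, the dimension actually compared against weight_size[1]:

import torch

inp = torch.rand(20, 10, 13)
weight = torch.nn.Linear(7, 10).weight      # stored as (out_features, in_features) = (10, 7)
print(inp.size()[1], inp.size()[-1], weight.size()[1])   # 10 13 7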