Additional unit test for sharded linear. (#70476)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/70476

1) Support a single dimension for inputs
2) Test several error cases

Partially addresses https://github.com/pytorch/pytorch/issues/65638
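
For illustration, a minimal sketch of the new single-dimension path (a sketch only, not part of the change; it assumes a 4-rank NCCL process group is already initialized and that rank holds the current rank, as in the test harness):

    import torch
    from torch.distributed._sharded_tensor import shard_parameter
    from torch.distributed._sharding_spec import ChunkShardingSpec

    # Chunk the linear weight across 4 GPUs, one shard per rank.
    spec = ChunkShardingSpec(
        dim=0,
        placements=[
            "rank:0/cuda:0",
            "rank:1/cuda:1",
            "rank:2/cuda:2",
            "rank:3/cuda:3",
        ],
    )

    fc = torch.nn.Linear(17, 12).cuda(rank)  # sizes mirror the tests below
    shard_parameter(fc, "weight", spec)      # weight becomes a ShardedTensor

    # Before this change the sharded op required input.dim() >= 2;
    # a single-dimension input is now accepted as well.
    out = fc(torch.rand(17).cuda(rank))
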
ghstack-source-id: 146307607

Test Plan: waitforbuildbot

Reviewed By: fduwjj

Differential Revision: D33344357

fbshipit-source-id: 4de7a7177452951dbcce76f27441703447609e6f
Pritam Damania 2022-01-19 17:19:32 -08:00 committed by Facebook GitHub Bot
parent 520226c1bf
commit 96dfded569
2 changed files with 100 additions and 3 deletions


@@ -5,8 +5,14 @@ import sys
import torch
import torch.distributed as dist
from torch.distributed._sharded_tensor import (
empty,
shard_parameter,
)
from torch.distributed._sharding_spec import (
ChunkShardingSpec,
EnumerableShardingSpec,
ShardMetadata
)
from torch.distributed._sharded_optim import (
ShardedOptimizer,
named_params_with_sharded_tensor,
@@ -128,6 +134,12 @@ class TestShardedTensorOpsLinear(ShardedTensorTestBase):
self._run_sharded_linear(spec, [27, 7, 23], [23, 13], 0)
self._run_sharded_linear(spec, [100, 12, 4, 15], [15, 14], 0)
# Test single input dim
self._run_sharded_linear(spec, [17], [17, 12], 0)
self._run_sharded_linear(spec, [21], [21, 11], 0)
self._run_sharded_linear(spec, [23], [23, 13], 0)
self._run_sharded_linear(spec, [15], [15, 14], 0)
@with_comms(init_rpc=False)
@skip_if_lt_x_gpu(TEST_GPU_NUM)
@requires_nccl()
@@ -145,6 +157,91 @@ class TestShardedTensorOpsLinear(ShardedTensorTestBase):
self._run_sharded_linear(spec, [10, 5, 19], [19, 11], 1)
self._run_sharded_linear(spec, [12, 15, 10, 21], [21, 11], 1)
# Test single input dim
self._run_sharded_linear(spec, [16], [16, 11], 1)
self._run_sharded_linear(spec, [19], [19, 11], 1)
self._run_sharded_linear(spec, [21], [21, 11], 1)
@with_comms(init_rpc=False)
@skip_if_lt_x_gpu(TEST_GPU_NUM)
@requires_nccl()
def test_sharded_linear_errors(self):
for spec in generate_chunk_sharding_specs_for_test(0):
fc1 = torch.nn.Linear(10, 10).cuda(self.rank)
shard_parameter(fc1, "bias", spec)
with self.assertRaisesRegex(TypeError, 'input and bias need to be torch.Tensor'):
fc1(torch.rand(10, 10).cuda(self.rank))
fc2 = torch.nn.Linear(10, 10).cuda(self.rank)
shard_parameter(fc2, "weight", spec)
with self.assertRaisesRegex(ValueError, 'Input needs to have at least 1 dim'):
fc2(torch.tensor(1).cuda(self.rank))
fc3 = torch.nn.Linear(10, 10).cuda(self.rank)
fc3.weight = torch.nn.Parameter(torch.rand(10, 10, 10).cuda(self.rank))
shard_parameter(fc3, "weight", spec)
with self.assertRaisesRegex(ValueError, 'Weight needs to have exactly 2 dims'):
fc3(torch.rand(10, 10).cuda(self.rank))
fc4 = torch.nn.Linear(10, 10).cuda(self.rank)
fc4.bias = torch.nn.Parameter(torch.rand(10, 10).cuda(self.rank))
shard_parameter(fc4, "weight", spec)
with self.assertRaisesRegex(ValueError, 'Bias needs to have exactly 1 dim'):
fc4(torch.rand(10, 10).cuda(self.rank))
fc5 = torch.nn.Linear(7, 10).cuda(self.rank)
shard_parameter(fc5, "weight", spec)
with self.assertRaisesRegex(ValueError, 'Input dim: 13 does not match appropriate weight dim: 7'):
fc5(torch.rand(20, 10, 13).cuda(self.rank))
fc6 = torch.nn.Linear(10, 10).cuda(self.rank)
del fc6.weight
enumerable_spec = EnumerableShardingSpec([
ShardMetadata(
shard_offsets=[0, 0],
shard_sizes=[5, 5],
placement="rank:0/cuda:0",
),
ShardMetadata(
shard_offsets=[0, 5],
shard_sizes=[5, 5],
placement="rank:1/cuda:1",
),
ShardMetadata(
shard_offsets=[5, 0],
shard_sizes=[5, 5],
placement="rank:2/cuda:2",
),
ShardMetadata(
shard_offsets=[5, 5],
shard_sizes=[5, 5],
placement="rank:3/cuda:3",
)
])
fc6.weight = empty(enumerable_spec, 10, 10)
with self.assertRaisesRegex(ValueError, 'Only ChunkShardingSpec supported for ShardedTensor ops!'):
fc6(torch.rand(10, 10).cuda(self.rank))
fc7 = torch.nn.Linear(10, 80).cuda(self.rank)
multiple_local_shard_spec = ChunkShardingSpec(
dim=0,
placements=[
"rank:0/cuda:0",
"rank:0/cuda:0",
"rank:1/cuda:1",
"rank:1/cuda:1",
"rank:2/cuda:2",
"rank:2/cuda:2",
"rank:3/cuda:3",
"rank:3/cuda:3",
],
)
del fc7.weight
fc7.weight = empty(multiple_local_shard_spec, 80, 10)
with self.assertRaisesRegex(ValueError, 'Only one local shard supported!'):
fc7(torch.rand(10, 10).cuda(self.rank))
if __name__ == "__main__":
run_tests()


@@ -90,8 +90,8 @@ def sharded_linear(types, args, kwargs, pg):
raise TypeError("input and bias need to be torch.Tensor")
if not isinstance(weight, ShardedTensor):
raise TypeError("weight needs to be ShardedTensor")
if len(input.size()) < 2:
raise ValueError("Input needs to have at least 2 dims")
if len(input.size()) < 1:
raise ValueError("Input needs to have at least 1 dim")
weight_size = cast(torch.Size, weight.size())
if len(weight_size) != 2:
raise ValueError("Weight needs to have exactly 2 dims")
@@ -100,7 +100,7 @@ def sharded_linear(types, args, kwargs, pg):
if input.size()[-1] != weight_size[1]:
raise ValueError(
f"Input dim: {input.size()[1]} does not match "
f"Input dim: {input.size()[-1]} does not match "
f"appropriate weight dim: {weight_size[1]}"
)
if not isinstance(weight._sharding_spec, ChunkShardingSpec):
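
For reference, a condensed sketch of the validation order after this change (a standalone helper written for illustration only; the name and signature are hypothetical and do not match the actual sharded_linear op), showing that a 1-D input now passes the dim check and that the mismatch error reports the last input dim:

    import torch

    def validate_sharded_linear_args(input, weight_size, bias):
        # Mirrors the checks in the hunk above; illustration only.
        if not isinstance(input, torch.Tensor) or not isinstance(bias, torch.Tensor):
            raise TypeError("input and bias need to be torch.Tensor")
        if len(input.size()) < 1:
            raise ValueError("Input needs to have at least 1 dim")
        if len(weight_size) != 2:
            raise ValueError("Weight needs to have exactly 2 dims")
        if input.size()[-1] != weight_size[1]:
            raise ValueError(
                f"Input dim: {input.size()[-1]} does not match "
                f"appropriate weight dim: {weight_size[1]}"
            )
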