From 96dfded5697e451b54f113f99b6d0da6f6af500d Mon Sep 17 00:00:00 2001
From: Pritam Damania
Date: Wed, 19 Jan 2022 17:19:32 -0800
Subject: [PATCH] Additional unit test for sharded linear. (#70476)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/70476

1) Support a single dimension for inputs
2) Test several error cases

Partially addresses https://github.com/pytorch/pytorch/issues/65638

ghstack-source-id: 146307607

Test Plan: waitforbuildbot

Reviewed By: fduwjj

Differential Revision: D33344357

fbshipit-source-id: 4de7a7177452951dbcce76f27441703447609e6f
---
 .../_sharded_tensor/ops/test_linear.py        | 97 +++++++++++++++++++
 .../distributed/_sharded_tensor/ops/linear.py |  6 +-
 2 files changed, 100 insertions(+), 3 deletions(-)

diff --git a/test/distributed/_sharded_tensor/ops/test_linear.py b/test/distributed/_sharded_tensor/ops/test_linear.py
index c6659138b95..352dd11707f 100644
--- a/test/distributed/_sharded_tensor/ops/test_linear.py
+++ b/test/distributed/_sharded_tensor/ops/test_linear.py
@@ -5,8 +5,14 @@ import sys
 import torch
 import torch.distributed as dist
 from torch.distributed._sharded_tensor import (
+    empty,
     shard_parameter,
 )
+from torch.distributed._sharding_spec import (
+    ChunkShardingSpec,
+    EnumerableShardingSpec,
+    ShardMetadata
+)
 from torch.distributed._sharded_optim import (
     ShardedOptimizer,
     named_params_with_sharded_tensor,
@@ -128,6 +134,12 @@ class TestShardedTensorOpsLinear(ShardedTensorTestBase):
         self._run_sharded_linear(spec, [27, 7, 23], [23, 13], 0)
         self._run_sharded_linear(spec, [100, 12, 4, 15], [15, 14], 0)
 
+        # Test single input dim
+        self._run_sharded_linear(spec, [17], [17, 12], 0)
+        self._run_sharded_linear(spec, [21], [21, 11], 0)
+        self._run_sharded_linear(spec, [23], [23, 13], 0)
+        self._run_sharded_linear(spec, [15], [15, 14], 0)
+
     @with_comms(init_rpc=False)
     @skip_if_lt_x_gpu(TEST_GPU_NUM)
     @requires_nccl()
@@ -145,6 +157,91 @@ class TestShardedTensorOpsLinear(ShardedTensorTestBase):
         self._run_sharded_linear(spec, [10, 5, 19], [19, 11], 1)
         self._run_sharded_linear(spec, [12, 15, 10, 21], [21, 11], 1)
 
+        # Test single input dim
+        self._run_sharded_linear(spec, [16], [16, 11], 1)
+        self._run_sharded_linear(spec, [19], [19, 11], 1)
+        self._run_sharded_linear(spec, [21], [21, 11], 1)
+
+    @with_comms(init_rpc=False)
+    @skip_if_lt_x_gpu(TEST_GPU_NUM)
+    @requires_nccl()
+    def test_sharded_linear_errors(self):
+        for spec in generate_chunk_sharding_specs_for_test(0):
+            fc1 = torch.nn.Linear(10, 10).cuda(self.rank)
+            shard_parameter(fc1, "bias", spec)
+            with self.assertRaisesRegex(TypeError, 'input and bias need to be torch.Tensor'):
+                fc1(torch.rand(10, 10).cuda(self.rank))
+
+            fc2 = torch.nn.Linear(10, 10).cuda(self.rank)
+            shard_parameter(fc2, "weight", spec)
+            with self.assertRaisesRegex(ValueError, 'Input needs to have at least 1 dim'):
+                fc2(torch.tensor(1).cuda(self.rank))
+
+            fc3 = torch.nn.Linear(10, 10).cuda(self.rank)
+            fc3.weight = torch.nn.Parameter(torch.rand(10, 10, 10).cuda(self.rank))
+            shard_parameter(fc3, "weight", spec)
+            with self.assertRaisesRegex(ValueError, 'Weight needs to have exactly 2 dims'):
+                fc3(torch.rand(10, 10).cuda(self.rank))
+
+            fc4 = torch.nn.Linear(10, 10).cuda(self.rank)
+            fc4.bias = torch.nn.Parameter(torch.rand(10, 10).cuda(self.rank))
+            shard_parameter(fc4, "weight", spec)
+            with self.assertRaisesRegex(ValueError, 'Bias needs to have exactly 1 dim'):
+                fc4(torch.rand(10, 10).cuda(self.rank))
+
+            fc5 = torch.nn.Linear(7, 10).cuda(self.rank)
+            shard_parameter(fc5, "weight", spec)
+            with self.assertRaisesRegex(ValueError, 'Input dim: 13 does not match appropriate weight dim: 7'):
+                fc5(torch.rand(20, 10, 13).cuda(self.rank))
+
+            fc6 = torch.nn.Linear(10, 10).cuda(self.rank)
+            del fc6.weight
+            enumerable_spec = EnumerableShardingSpec([
+                ShardMetadata(
+                    shard_offsets=[0, 0],
+                    shard_sizes=[5, 5],
+                    placement="rank:0/cuda:0",
+                ),
+                ShardMetadata(
+                    shard_offsets=[0, 5],
+                    shard_sizes=[5, 5],
+                    placement="rank:1/cuda:1",
+                ),
+                ShardMetadata(
+                    shard_offsets=[5, 0],
+                    shard_sizes=[5, 5],
+                    placement="rank:2/cuda:2",
+                ),
+                ShardMetadata(
+                    shard_offsets=[5, 5],
+                    shard_sizes=[5, 5],
+                    placement="rank:3/cuda:3",
+                )
+            ])
+
+            fc6.weight = empty(enumerable_spec, 10, 10)
+            with self.assertRaisesRegex(ValueError, 'Only ChunkShardingSpec supported for ShardedTensor ops!'):
+                fc6(torch.rand(10, 10).cuda(self.rank))
+
+            fc7 = torch.nn.Linear(10, 80).cuda(self.rank)
+            multiple_local_shard_spec = ChunkShardingSpec(
+                dim=0,
+                placements=[
+                    "rank:0/cuda:0",
+                    "rank:0/cuda:0",
+                    "rank:1/cuda:1",
+                    "rank:1/cuda:1",
+                    "rank:2/cuda:2",
+                    "rank:2/cuda:2",
+                    "rank:3/cuda:3",
+                    "rank:3/cuda:3",
+                ],
+            )
+            del fc7.weight
+            fc7.weight = empty(multiple_local_shard_spec, 80, 10)
+            with self.assertRaisesRegex(ValueError, 'Only one local shard supported!'):
+                fc7(torch.rand(10, 10).cuda(self.rank))
+
 
 if __name__ == "__main__":
     run_tests()
diff --git a/torch/distributed/_sharded_tensor/ops/linear.py b/torch/distributed/_sharded_tensor/ops/linear.py
index 464490d5310..6ecec716352 100644
--- a/torch/distributed/_sharded_tensor/ops/linear.py
+++ b/torch/distributed/_sharded_tensor/ops/linear.py
@@ -90,8 +90,8 @@ def sharded_linear(types, args, kwargs, pg):
         raise TypeError("input and bias need to be torch.Tensor")
     if not isinstance(weight, ShardedTensor):
         raise TypeError("weight needs to be ShardedTensor")
-    if len(input.size()) < 2:
-        raise ValueError("Input needs to have at least 2 dims")
+    if len(input.size()) < 1:
+        raise ValueError("Input needs to have at least 1 dim")
     weight_size = cast(torch.Size, weight.size())
     if len(weight_size) != 2:
         raise ValueError("Weight needs to have exactly 2 dims")
@@ -100,7 +100,7 @@
 
     if input.size()[-1] != weight_size[1]:
         raise ValueError(
-            f"Input dim: {input.size()[1]} does not match "
+            f"Input dim: {input.size()[-1]} does not match "
             f"appropriate weight dim: {weight_size[1]}"
         )
     if not isinstance(weight._sharding_spec, ChunkShardingSpec):
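
Illustration (not part of the patch): a minimal sketch of the call path the new single-dim input tests and error tests exercise. It assumes a 4-rank NCCL process group has already been initialized (e.g. via torchrun) with one GPU per rank; the layer sizes mirror one of the new test cases and are otherwise arbitrary.

    import torch
    import torch.distributed as dist
    from torch.distributed._sharding_spec import ChunkShardingSpec
    from torch.distributed._sharded_tensor import shard_parameter

    # Assumes dist.init_process_group("nccl", ...) has already run on 4 ranks.
    rank = dist.get_rank()

    # Shard the Linear weight across 4 GPUs along dim 0 (the output features).
    spec = ChunkShardingSpec(
        dim=0,
        placements=[
            "rank:0/cuda:0",
            "rank:1/cuda:1",
            "rank:2/cuda:2",
            "rank:3/cuda:3",
        ],
    )
    fc = torch.nn.Linear(17, 12).cuda(rank)
    shard_parameter(fc, "weight", spec)  # fc.weight becomes a ShardedTensor

    # With this patch a 1-D input is accepted (previously at least 2 dims were required).
    out = fc(torch.rand(17).cuda(rank))

    # A 0-dim input now fails with the updated validation message:
    #   ValueError: Input needs to have at least 1 dim
    # fc(torch.tensor(1).cuda(rank))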