[shard] Add ReplicatedTensor (#73529)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/73529

Add ReplicatedTensor. A ReplicatedTensor is a type of tensor that has the same value on all ranks across the world_size.

ReplicatedTensor is a :class:`~torch.Tensor` subclass, and it can be used together with ShardedTensor/Tensor to express different types of computation. The inter-op rules are defined as follows (using torch.add as an example op):
    ReplicatedTensor + ReplicatedTensor = ReplicatedTensor
    ReplicatedTensor + torch.Tensor = torch.Tensor
    ReplicatedTensor + ShardedTensor = ShardedTensor

We also added a `validate()` API to help users validate whether a replicated tensor on a certain process_group is truly replicated.
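
A minimal usage sketch of the rules above (not part of this diff; assumes the default process group has already been initialized, e.g. via `dist.init_process_group`, and that each rank has a GPU):

    import torch
    import torch.distributed as dist
    from torch.distributed._shard.replicated_tensor import ReplicatedTensor

    rank = dist.get_rank()
    # every rank constructs the same local value, so the tensor is truly replicated
    local_tensor = torch.ones(3, 3, device=f"cuda:{rank}") * 4
    replica = ReplicatedTensor(local_tensor)
    replica.validate()  # all_gathers the tensor and raises ValueError on any mismatch

    # ReplicatedTensor op ReplicatedTensor -> ReplicatedTensor
    res = replica * ReplicatedTensor(local_tensor)
    assert isinstance(res, ReplicatedTensor)

    # ReplicatedTensor op plain torch.Tensor (or scalar) -> plain torch.Tensor
    res = replica + torch.randn(3, 3, device=f"cuda:{rank}")
    assert not isinstance(res, ReplicatedTensor)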

TODO: a follow-up PR will add ShardedTensor/PartialTensor logic to handle ReplicatedTensor.
ghstack-source-id: 152064781

Test Plan: test_replicated_tensor

Reviewed By: pritamdamania87, fduwjj

Differential Revision: D34529374

fbshipit-source-id: 16ccb300e9f9c47ac29a17eb6d46d029ab7d60b8
(cherry picked from commit 44f4e11e795a1bf330a8108bda256950ca769525)
Authored by Wanchao Liang on 2022-03-24 05:36:04 -07:00; committed by PyTorch MergeBot
parent c9612cddb7
commit 0524b2829a
6 changed files with 223 additions and 2 deletions


@ -0,0 +1,76 @@
# Owner(s): ["oncall: distributed"]

import torch
import torch.distributed as dist

from torch.testing._internal.common_distributed import (
    requires_nccl,
    skip_if_lt_x_gpu,
)
from torch.testing._internal.distributed._shard.sharded_tensor import (
    ShardedTensorTestBase,
    with_comms,
)
from torch.distributed._shard.replicated_tensor import ReplicatedTensor


class TestReplicatedTensor(ShardedTensorTestBase):

    @with_comms(init_rpc=False)
    @skip_if_lt_x_gpu(4)
    @requires_nccl()
    def test_replicated_tensor_basics(self):
        local_tensor = torch.ones(3, 3, device=f"cuda:{self.rank}") * 4
        replica_tensor = ReplicatedTensor(local_tensor)
        print(replica_tensor.process_group)
        # validate it's a replicated tensor by checking values on all ranks
        validated = replica_tensor.validate()
        self.assertEqual(validated, True)

        res = replica_tensor + 2
        self.assertIsInstance(res, torch.Tensor)
        self.assertNotIsInstance(res, ReplicatedTensor)
        self.assertEqual(res, torch.ones(3, 3) * 6)

        # modify the local tensor on a certain rank, and check that validation raises
        if self.rank == 2:
            local_tensor += 3
        with self.assertRaisesRegex(ValueError, 'have different values'):
            replica_tensor.validate()

    @with_comms(init_rpc=False)
    @skip_if_lt_x_gpu(4)
    @requires_nccl()
    def test_replicated_tensor_inter_op_replicated_tensor(self):
        local_tensor = torch.ones(3, 3, device=f"cuda:{self.rank}")
        replica_tensor1 = ReplicatedTensor(local_tensor * 4)
        replica_tensor2 = ReplicatedTensor(local_tensor * 6)

        new_tensor = replica_tensor1 * replica_tensor2
        self.assertIsInstance(new_tensor, ReplicatedTensor)
        self.assertEqual(new_tensor, torch.ones(3, 3) * 24)

        # test replicated tensor inter-op with different pgs
        new_pg = dist.new_group(ranks=[1, 2, 3])
        replica_tensor_new_group = ReplicatedTensor(local_tensor * 3, process_group=new_pg)

        with self.assertRaisesRegex(RuntimeError, 'must be in the same'):
            replica_tensor_new_group * replica_tensor1

    @with_comms(init_rpc=False)
    @skip_if_lt_x_gpu(4)
    @requires_nccl()
    def test_replicated_tensor_inter_op_tensor(self):
        local_tensor = torch.ones(3, 3, device=f"cuda:{self.rank}") * 4
        replica_tensor = ReplicatedTensor(local_tensor)

        local_rand_tensor = torch.randn(3, 3, device=f"cuda:{self.rank}")

        new_tensor = replica_tensor + local_rand_tensor
        self.assertIsInstance(new_tensor, torch.Tensor)
        self.assertNotIsInstance(new_tensor, ReplicatedTensor)
        self.assertEqual(new_tensor, local_tensor + local_rand_tensor)


@ -214,6 +214,7 @@ WINDOWS_BLOCKLIST = [
"distributed/_shard/sharded_tensor/ops/test_linear",
"distributed/_shard/sharding_spec/test_sharding_spec",
"distributed/_shard/sharded_optim/test_sharded_optim",
"distributed/_shard/test_replicated_tensor",
] + FSDP_TEST
ROCM_BLOCKLIST = [
@ -233,6 +234,7 @@ ROCM_BLOCKLIST = [
"distributed/_shard/sharded_tensor/ops/test_linear",
"distributed/_shard/sharding_spec/test_sharding_spec",
"distributed/_shard/sharded_optim/test_sharded_optim",
"distributed/_shard/test_replicated_tensor",
"test_determination",
"test_jit_legacy",
"test_type_hints",


@ -1 +1 @@
-from .api import shard_parameter, _shard_tensor
+from .api import shard_parameter, _shard_tensor, _replicate_tensor


@ -7,6 +7,7 @@ from torch.distributed._shard.sharded_tensor import (
from .sharding_spec import (
    ShardingSpec,
)
from .replicated_tensor import ReplicatedTensor

def _shard_tensor(
    tensor: torch.Tensor, sharding_spec: ShardingSpec, src_rank=0, process_group=None
@ -118,3 +119,20 @@ def shard_parameter(
    # Now we can set the attribute appropriately.
    setattr(module, param_name, st)

def _replicate_tensor(tensor: torch.Tensor, process_group=None) -> ReplicatedTensor:
    """
    Given a :class:`torch.Tensor`, mark it as a ReplicatedTensor where all
    ranks have the same value.

    Args:
        tensor (:class:`torch.Tensor`): the tensor to be marked as replicated.

    Keyword args:
        process_group (ProcessGroup, optional): The process group to replicate on.
            If None, the default process group will be used.

    Returns:
        A :class:`ReplicatedTensor` from the given tensor.
    """
    return ReplicatedTensor(tensor, process_group=process_group)
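
(Not part of the diff: a small hedged sketch of how `_replicate_tensor` might be used, assuming `torch.distributed.init_process_group` has already run on every rank so the default process group exists.)

    import torch
    from torch.distributed._shard import _replicate_tensor

    # every rank passes the same local value; the default process group is used
    weight = torch.ones(4, 4)
    replicated_weight = _replicate_tensor(weight)
    # replicated_weight is a ReplicatedTensor wrapping the same data;
    # replicated_weight.validate() can confirm the values match across ranks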


@ -0,0 +1,125 @@
import torch
import torch.distributed as dist

from torch.overrides import get_default_nowrap_functions
from torch.distributed._shard.sharded_tensor.api import ShardedTensor
from torch.distributed import distributed_c10d


class ReplicatedTensor(torch.Tensor):
    """
    ReplicatedTensor represents a tensor which is replicated across the `world_size` and
    has the same value on each rank.

    ReplicatedTensor is a :class:`~torch.Tensor` subclass, and it can be used together
    with ShardedTensor/Tensor to express different types of computation. The inter-op
    rules are defined as follows (using torch.add as an example op):
        ReplicatedTensor + ReplicatedTensor = ReplicatedTensor
        ReplicatedTensor + torch.Tensor = torch.Tensor
        ReplicatedTensor + ShardedTensor = ShardedTensor
        ReplicatedTensor + other type (i.e. Scalar) = other type

    NOTE: We do not guarantee equal content of ReplicatedTensor across nodes after its
    construction. Although we defined proper inter-op rules to make sure ReplicatedTensor
    stays the same, there's no enforcement on it (i.e. if you manually modify content on
    some ranks, the modified value will not automatically get synced to other nodes). If
    you wish to manually validate that tensors are the same across ranks, use `validate()`.
    """

    process_group: distributed_c10d.ProcessGroup

    __slots__ = ["process_group"]

    def __new__(cls, data=None, process_group=None):
        if data is None:
            data = torch.empty(0)
        r = torch.Tensor._make_subclass(cls, data)  # type: ignore[arg-type]
        r.process_group = (  # type: ignore[attr-defined]
            process_group
            if process_group is not None
            else distributed_c10d._get_default_group()
        )
        return r

    def __repr__(self):
        return f"ReplicatedTensor({super(ReplicatedTensor, self).__repr__()})"

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}

        # We re-dispatch the execution to ShardedTensor's __torch_function__ if we find
        # any ShardedTensor operands. We also check whether args/kwargs are all
        # ReplicatedTensor operands; we have to do this to ensure we do not convert
        # results back to ReplicatedTensor when not all operands are replicated.
        all_replicated = True
        replicated_pg = None

        def dispatch_arg(arg):
            nonlocal replicated_pg, all_replicated
            if isinstance(arg, ShardedTensor):
                # redispatch to ShardedTensor
                # TODO: handle ShardedTensor/PartialTensor inter-op with ReplicatedTensor
                return arg.__torch_function__(func, types, args, kwargs)
            if isinstance(arg, ReplicatedTensor):
                if replicated_pg is None:
                    replicated_pg = arg.process_group
                elif replicated_pg != arg.process_group:
                    raise RuntimeError(
                        f"ReplicatedTensor operands must be in the same process group "
                        f"in torch function '{func.__name__}', but found at least two "
                        f"ReplicatedTensor operands in different process groups! ")
            else:
                all_replicated = False

        for arg in args:
            dispatch_arg(arg)

        if kwargs is not None:
            for k, v in kwargs.items():
                dispatch_arg(v)

        # We can't use super().__torch_function__() as it implicitly converts the result
        # back to tensor subclasses, whereas in our case we need to control the output
        # type based on the inter-op rules we defined.
        with torch._C.DisableTorchFunction():
            rs = func(*args, **kwargs)
            if func in get_default_nowrap_functions():
                return rs

            if all_replicated and isinstance(rs, torch.Tensor) and not isinstance(rs, cls):
                # if all operands are ReplicatedTensors and the op did not get dispatched
                # to ShardedTensor's __torch_function__, the result is a plain torch.Tensor,
                # so we convert it and return a ReplicatedTensor according to our inter-op rule
                rs = rs.as_subclass(cls)  # type: ignore[arg-type]
                # propagate the process_group field to the result
                rs.process_group = replicated_pg  # type: ignore[attr-defined]

            return rs

    def validate(self) -> bool:
        """
        Validate that this ReplicatedTensor is legit by all-gathering the tensor from
        all ranks of its process group and checking that the values are the same.

        If some ranks have different values, a ValueError is raised.

        Returns:
            True if validation succeeds.
        """
        world_size = dist.get_world_size(self.process_group)
        current_rank = dist.get_rank(self.process_group)

        tensors_on_rank = [torch.empty_like(self) for _ in range(world_size)]

        dist.all_gather(tensors_on_rank, self, group=self.process_group)
        # validate and check whether all gathered tensors are equal
        for rank, tensor in enumerate(tensors_on_rank):
            if not torch.allclose(self, tensor):
                raise ValueError(
                    f"ReplicatedTensor have different values on rank {current_rank} and {rank}")

        return True


@ -366,7 +366,7 @@ def sharded_op_impl(func):
    parameters, the function provided will be invoked for that operator.
    Example::
-        >>> @custom_sharded_op(torch.nn.functional.linear)
+        >>> @sharded_op_impl(torch.nn.functional.linear)
        >>> def my_custom_sharded_linear(types, args, kwargs, process_group):
        >>>     ....
        >>>