[shard] Add ReplicatedTensor (#73529)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/73529

Add ReplicatedTensor. A ReplicatedTensor is a type of tensor that has the same value on all ranks across the world_size. ReplicatedTensor is a :class:`~torch.Tensor` subclass, and it can be used together with ShardedTensor/Tensor to express different types of computation. The inter-op rules are defined as (using torch.add as an example op):

    ReplicatedTensor + ReplicatedTensor = ReplicatedTensor
    ReplicatedTensor + torch.Tensor = torch.Tensor
    ReplicatedTensor + ShardedTensor = ShardedTensor

We also added a `validate()` API to help users check whether a replicated tensor on a certain process_group is truly replicated. TODO: the next PR will add ShardedTensor/PartialTensor logic to handle ReplicatedTensor.

ghstack-source-id: 152064781
Test Plan: test_replicated_tensor
Reviewed By: pritamdamania87, fduwjj
Differential Revision: D34529374
fbshipit-source-id: 16ccb300e9f9c47ac29a17eb6d46d029ab7d60b8
(cherry picked from commit 44f4e11e795a1bf330a8108bda256950ca769525)
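To make the inter-op rules concrete, here is a minimal sketch. It is not part of this diff; it assumes torch.distributed is already initialized on every rank, and the shapes are illustrative:

import torch
from torch.distributed._shard.replicated_tensor import ReplicatedTensor

# assumes dist.init_process_group(...) has already run on every rank
r1 = ReplicatedTensor(torch.ones(2, 2))
r2 = ReplicatedTensor(torch.ones(2, 2) * 2)

out = r1 + r2                  # ReplicatedTensor + ReplicatedTensor -> ReplicatedTensor
mixed = r1 + torch.rand(2, 2)  # ReplicatedTensor + torch.Tensor -> plain torch.Tensor
assert isinstance(out, ReplicatedTensor)
assert not isinstance(mixed, ReplicatedTensor)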
This commit is contained in:
parent c9612cddb7
commit 0524b2829a
test/distributed/_shard/test_replicated_tensor.py (new file, 76 lines)

@@ -0,0 +1,76 @@
# Owner(s): ["oncall: distributed"]

import torch
import torch.distributed as dist

from torch.testing._internal.common_distributed import (
    requires_nccl,
    skip_if_lt_x_gpu,
)
from torch.testing._internal.distributed._shard.sharded_tensor import (
    ShardedTensorTestBase,
    with_comms,
)
from torch.distributed._shard.replicated_tensor import ReplicatedTensor


class TestReplicatedTensor(ShardedTensorTestBase):

    @with_comms(init_rpc=False)
    @skip_if_lt_x_gpu(4)
    @requires_nccl()
    def test_replicated_tensor_basics(self):
        local_tensor = torch.ones(3, 3, device=f"cuda:{self.rank}") * 4
        replica_tensor = ReplicatedTensor(local_tensor)
        # validate it's a replicated tensor by checking values on all ranks
        validated = replica_tensor.validate()
        self.assertEqual(validated, True)
        res = replica_tensor + 2
        self.assertIsInstance(res, torch.Tensor)
        self.assertNotIsInstance(res, ReplicatedTensor)
        self.assertEqual(res, torch.ones(3, 3) * 6)

        # modify the local tensor on one rank, and check that validation raises
        if self.rank == 2:
            local_tensor += 3

        with self.assertRaisesRegex(ValueError, 'have different values'):
            replica_tensor.validate()

    @with_comms(init_rpc=False)
    @skip_if_lt_x_gpu(4)
    @requires_nccl()
    def test_replicated_tensor_inter_op_replicated_tensor(self):
        local_tensor = torch.ones(3, 3, device=f"cuda:{self.rank}")
        replica_tensor1 = ReplicatedTensor(local_tensor * 4)
        replica_tensor2 = ReplicatedTensor(local_tensor * 6)

        new_tensor = replica_tensor1 * replica_tensor2
        self.assertIsInstance(new_tensor, ReplicatedTensor)
        self.assertEqual(new_tensor, torch.ones(3, 3) * 24)

        # test replicated tensor inter-op across different process groups
        new_pg = dist.new_group(ranks=[1, 2, 3])
        replica_tensor_new_group = ReplicatedTensor(local_tensor * 3, process_group=new_pg)

        with self.assertRaisesRegex(RuntimeError, 'must be in the same'):
            replica_tensor_new_group * replica_tensor1

    @with_comms(init_rpc=False)
    @skip_if_lt_x_gpu(4)
    @requires_nccl()
    def test_replicated_tensor_inter_op_tensor(self):
        local_tensor = torch.ones(3, 3, device=f"cuda:{self.rank}") * 4
        replica_tensor = ReplicatedTensor(local_tensor)

        local_rand_tensor = torch.randn(3, 3, device=f"cuda:{self.rank}")

        new_tensor = replica_tensor + local_rand_tensor
        self.assertIsInstance(new_tensor, torch.Tensor)
        self.assertNotIsInstance(new_tensor, ReplicatedTensor)

        self.assertEqual(new_tensor, local_tensor + local_rand_tensor)

@@ -214,6 +214,7 @@ WINDOWS_BLOCKLIST = [
     "distributed/_shard/sharded_tensor/ops/test_linear",
     "distributed/_shard/sharding_spec/test_sharding_spec",
     "distributed/_shard/sharded_optim/test_sharded_optim",
+    "distributed/_shard/test_replicated_tensor",
 ] + FSDP_TEST

 ROCM_BLOCKLIST = [

@@ -233,6 +234,7 @@ ROCM_BLOCKLIST = [
     "distributed/_shard/sharded_tensor/ops/test_linear",
     "distributed/_shard/sharding_spec/test_sharding_spec",
     "distributed/_shard/sharded_optim/test_sharded_optim",
+    "distributed/_shard/test_replicated_tensor",
     "test_determination",
     "test_jit_legacy",
     "test_type_hints",
@@ -1 +1 @@
-from .api import shard_parameter, _shard_tensor
+from .api import shard_parameter, _shard_tensor, _replicate_tensor
@@ -7,6 +7,7 @@ from torch.distributed._shard.sharded_tensor import (
 from .sharding_spec import (
     ShardingSpec,
 )
+from .replicated_tensor import ReplicatedTensor

 def _shard_tensor(
     tensor: torch.Tensor, sharding_spec: ShardingSpec, src_rank=0, process_group=None
@@ -118,3 +119,20 @@ def shard_parameter(

     # Now we can set the attribute appropriately.
     setattr(module, param_name, st)
+
+
+def _replicate_tensor(tensor: torch.Tensor, process_group=None) -> ReplicatedTensor:
+    """
+    Given a :class:`torch.Tensor`, mark it as a ReplicatedTensor where all
+    ranks have the same value.
+
+    Args:
+        tensor (:class:`torch.Tensor`): the tensor to be marked as replicated.
+    Keyword args:
+        process_group (ProcessGroup, optional): The process group to replicate on.
+            If None, the default process group will be used.
+    Returns:
+        A :class:`ReplicatedTensor` from the given tensor.
+
+    """
+    return ReplicatedTensor(tensor, process_group=process_group)
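A short usage sketch for the new helper. It is not part of this diff; it assumes the default process group is initialized and that every rank passes the same tensor value:

import torch
from torch.distributed._shard import _replicate_tensor

# every rank constructs the same local value and marks it replicated
rt = _replicate_tensor(torch.zeros(4, 4))
# collective check: every rank in the group must call validate() together
assert rt.validate()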
torch/distributed/_shard/replicated_tensor.py (new file, 125 lines)

@@ -0,0 +1,125 @@
import torch
import torch.distributed as dist

from torch.overrides import get_default_nowrap_functions
from torch.distributed._shard.sharded_tensor.api import ShardedTensor
from torch.distributed import distributed_c10d


class ReplicatedTensor(torch.Tensor):
    """
    ReplicatedTensor represents a tensor which is replicated across the `world_size` and
    has the same value on each rank.

    ReplicatedTensor is a :class:`~torch.Tensor` subclass, and it can be used together
    with ShardedTensor/Tensor to express different types of computation. The
    inter-op rules are defined as (using torch.add as an example op):
        ReplicatedTensor + ReplicatedTensor = ReplicatedTensor
        ReplicatedTensor + torch.Tensor = torch.Tensor
        ReplicatedTensor + ShardedTensor = ShardedTensor
        ReplicatedTensor + other type (i.e. Scalar) = other type

    NOTE: We do not guarantee equal content of a ReplicatedTensor across nodes after its
    construction. Although we defined proper inter-op rules to make sure a ReplicatedTensor
    stays the same, there is no enforcement of it (i.e. if you manually modify content on
    some ranks, the modified value will not automatically get synced to other nodes). If
    you wish to manually validate that tensors are the same across ranks, use `validate()`.
    """
    process_group: distributed_c10d.ProcessGroup

    __slots__ = ["process_group"]

    def __new__(cls, data=None, process_group=None):
        if data is None:
            data = torch.empty(0)
        r = torch.Tensor._make_subclass(cls, data)  # type: ignore[arg-type]
        r.process_group = (  # type: ignore[attr-defined]
            process_group
            if process_group is not None
            else distributed_c10d._get_default_group()
        )
        return r

    def __repr__(self):
        return f"ReplicatedTensor({super(ReplicatedTensor, self).__repr__()})"

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        # We re-dispatch the execution to ShardedTensor's __torch_function__ if we
        # find ShardedTensor operands. We also check whether args/kwargs are all
        # ReplicatedTensor operands; we have to do this to ensure we do not convert
        # results back to ReplicatedTensor when not all operands are replicated.
        all_replicated = True
        replicated_pg = None

        def dispatch_arg(arg):
            nonlocal replicated_pg, all_replicated
            if isinstance(arg, ShardedTensor):
                # redispatch to ShardedTensor
                # TODO: handle ShardedTensor/PartialTensor inter-op with ReplicatedTensor
                return arg.__torch_function__(func, types, args, kwargs)
            if isinstance(arg, ReplicatedTensor):
                if replicated_pg is None:
                    replicated_pg = arg.process_group
                elif replicated_pg != arg.process_group:
                    raise RuntimeError(
                        f"ReplicatedTensor operands must be in the same process group "
                        f"in torch function '{func.__name__}', but found at least two "
                        f"ReplicatedTensor operands in different process groups!")
            else:
                all_replicated = False

        for arg in args:
            dispatch_arg(arg)

        if kwargs is not None:
            for k, v in kwargs.items():
                dispatch_arg(v)

        # We can't use super().__torch_function__() as it implicitly converts the
        # result back to tensor subclasses, whereas in our case we need to control
        # the output type based on the inter-op rules we defined.
        with torch._C.DisableTorchFunction():
            rs = func(*args, **kwargs)
            if func in get_default_nowrap_functions():
                return rs
            if all_replicated and isinstance(rs, torch.Tensor) and not isinstance(rs, cls):
                # If all operands are ReplicatedTensors and the call did not get
                # dispatched to ShardedTensor's __torch_function__ (i.e. the result
                # is a plain torch.Tensor), convert it back to a ReplicatedTensor
                # according to our inter-op rules.
                rs = rs.as_subclass(cls)  # type: ignore[arg-type]
                # propagate the process_group field to the result
                rs.process_group = replicated_pg  # type: ignore[attr-defined]

            return rs

    def validate(self) -> bool:
        """
        Validate that this ReplicatedTensor is legitimate by all-gathering the
        tensors from all ranks in its process group and checking that they are
        the same.

        If some ranks have different values, a ValueError is raised.

        Returns:
            True if validation succeeds.
        """
        world_size = dist.get_world_size(self.process_group)
        current_rank = dist.get_rank(self.process_group)

        tensors_on_rank = [torch.empty_like(self) for _ in range(world_size)]

        dist.all_gather(tensors_on_rank, self, group=self.process_group)
        # check that all gathered tensors are equal
        for rank, tensor in enumerate(tensors_on_rank):
            if not torch.allclose(self, tensor):
                raise ValueError(
                    f"ReplicatedTensor have different values on rank {current_rank} and {rank}")

        return True
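For readers unfamiliar with the pattern in __torch_function__ above, here is a minimal single-process sketch of the same re-wrapping trick: run the op with torch function disabled, then convert plain Tensor results back to the subclass with as_subclass. TaggedTensor is an illustrative name, not part of this diff:

import torch

class TaggedTensor(torch.Tensor):
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        # run the real op without re-entering __torch_function__
        with torch._C.DisableTorchFunction():
            rs = func(*args, **kwargs)
        if isinstance(rs, torch.Tensor) and not isinstance(rs, cls):
            rs = rs.as_subclass(cls)  # re-wrap the plain result into the subclass
        return rs

t = torch.ones(2).as_subclass(TaggedTensor)
assert isinstance(t + 1, TaggedTensor)  # results keep the subclass type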
@@ -366,7 +366,7 @@ def sharded_op_impl(func):
     parameters, the function provided will be invoked for that operator.

     Example::
-        >>> @custom_sharded_op(torch.nn.functional.linear)
+        >>> @sharded_op_impl(torch.nn.functional.linear)
         >>> def my_custom_sharded_linear(types, args, kwargs, process_group):
         >>>     ....
         >>>
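As a hedged illustration of the corrected decorator name, a registration might look like the sketch below. The body is a placeholder, and the import path is an assumption based on how ops modules of this era import sharded_op_impl:

import torch
from torch.distributed._shard.sharded_tensor import sharded_op_impl

@sharded_op_impl(torch.nn.functional.linear)
def my_custom_sharded_linear(types, args, kwargs, process_group):
    # a real implementation would perform the sharded linear computation here;
    # this body is illustrative only
    raise NotImplementedError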