This PR canonicalizes the detach call site so that detach is only called from `distribute_tensor`. Other call sites are changed to `view_as`, and the detach call in the tensor constructor is removed, so that we don't detach the local tensor on every op run when rewrapping the DTensor. Pull Request resolved: https://github.com/pytorch/pytorch/pull/105239 Approved by: https://github.com/albanD
# Copyright (c) Meta Platforms, Inc. and affiliates
# Owner(s): ["oncall: distributed"]

import itertools

import torch
from torch.distributed._tensor import DeviceMesh, distribute_tensor, DTensor
from torch.distributed._tensor.placement_types import _Partial, Replicate, Shard
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
    DTensorTestBase,
    with_comms,
)


class RedistributeTest(DTensorTestBase):
    @with_comms
    def test_shard_to_replicate_forward_backward(self):
        # 1) test shard -> replicate forward
        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
        replica_spec = [Replicate()]

        input_sizes_and_shard_dim = [
            ((self.world_size * 3, 3), 0),
            ((self.world_size * 3 + 1, 3), 0),
            ((self.world_size * 3 + 2, 3), 0),
            ((3, self.world_size * 3), 1),
            ((3, self.world_size * 3 + 1), 1),
            ((3, self.world_size * 3 + 2), 1),
        ]
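        # The +1 / +2 sizes above exercise uneven sharding, where the last
        # rank gets a smaller chunk. For example (assuming world_size == 4),
        # a dim of size 4 * 3 + 2 == 14 is chunked as [4, 4, 4, 2].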

        for input_size, shard_dim in input_sizes_and_shard_dim:
            shard_spec = [Shard(shard_dim)]
            expected_tensor = torch.randn(
                input_size, device=self.device_type, requires_grad=True
            )
            dtensor = distribute_tensor(expected_tensor, device_mesh, shard_spec)
            reshard_dtensor = dtensor.redistribute(device_mesh, replica_spec)
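            # At the time of writing, Shard -> Replicate is implemented as an
            # all_gather along the sharded dim, so every rank ends up with the
            # full tensor; the autograd backward of that gather is a local
            # slice, which is why the gradient below comes back sharded.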
            self.assertEqual(reshard_dtensor.size(), torch.Size(input_size))
            self.assertEqual(expected_tensor, reshard_dtensor.to_local())

            # 2) test shard -> replicate backward:
            # should give gradient as shard
            grad_output = torch.ones_like(reshard_dtensor)
            reshard_dtensor.backward(grad_output)
            grad_input = dtensor.grad
            self.assertEqual(grad_input.placements, shard_spec)
            self.assertEqual(
                grad_input.to_local(), torch.ones(dtensor.to_local().size())
            )

    @with_comms
    def test_replicate_to_replicate_forward_backward(self):
        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
        replica_spec = [Replicate()]
        local_tensor = torch.randn(12, 3, device=self.device_type, requires_grad=True)
        # 1) test replicate -> replicate forward
        replica_tensor = distribute_tensor(local_tensor, device_mesh, replica_spec)
        reshard_replica_tensor = replica_tensor.redistribute(device_mesh, replica_spec)
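        # Redistributing to the same placement should be a no-op: no
        # collective is needed, and the result compares equal to the input.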
        self.assertEqual(replica_tensor.size(), local_tensor.size())
        self.assertEqual(replica_tensor, reshard_replica_tensor)

        # 2) test replicate -> replicate backward:
        # should give gradient as replicate
        grad_output = torch.ones_like(reshard_replica_tensor)
        reshard_replica_tensor.backward(grad_output)
        grad_input = replica_tensor.grad
        self.assertEqual(grad_input.placements, replica_spec)
        self.assertEqual(grad_input.to_local(), torch.ones(12, 3))

    @with_comms
    def test_replicate_to_shard_forward_backward(self):
        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
        replica_spec = [Replicate()]

        input_sizes_and_shard_dim = [
            ((self.world_size * 3, 3), 0),
            ((self.world_size * 3 + 1, 3), 0),
            ((self.world_size * 3 + 2, 3), 0),
            ((3, self.world_size * 3), 1),
            ((3, self.world_size * 3 + 1), 1),
            ((3, self.world_size * 3 + 2), 1),
        ]
        for input_size, shard_dim in input_sizes_and_shard_dim:
            shard_spec = [Shard(shard_dim)]
            # 1) test replicate -> shard forward
            local_replica = torch.randn(
                input_size, device=self.device_type, requires_grad=True
            )
            split_list = list(
                torch.chunk(local_replica, self.world_size, dim=shard_dim)
            )

            # use the corresponding chunk as this rank's expected local tensor
            local_tensor = split_list[self.rank]
            replica_tensor = distribute_tensor(local_replica, device_mesh, replica_spec)
            reshard_tensor = replica_tensor.redistribute(device_mesh, shard_spec)
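            # Replicate -> Shard should need no communication: every rank
            # already holds the full tensor, so each one can simply keep its
            # own chunk locally.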
            self.assertEqual(reshard_tensor.size(), replica_tensor.size())
            self.assertEqual(reshard_tensor.placements, shard_spec)
            self.assertEqual(reshard_tensor.to_local(), local_tensor)

            # 2) test replicate -> shard backward:
            # should give gradient as replicate
            grad_output = torch.ones_like(reshard_tensor)
            reshard_tensor.backward(grad_output)
            grad_input = replica_tensor.grad
            self.assertEqual(grad_input.placements, replica_spec)
            self.assertEqual(grad_input.to_local(), torch.ones(input_size))

    @with_comms
    def test_partial_to_replicate_forward_backward(self):
        # Although we don't allow users to reshard to a partial placement
        # (i.e. a user can't redistribute to partial), we do allow replicate
        # to partial internally, and partial to replicate backward should
        # work as expected.
        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
        partial_local = torch.randn(12, 3, device=self.device_type, requires_grad=True)
        partial_spec = [_Partial()]
        replica_spec = [Replicate()]
        # test partial -> replicate, which triggers an all_reduce
        partial_tensor = DTensor.from_local(partial_local, device_mesh, partial_spec)
        global_partial_tensor = partial_tensor.redistribute(device_mesh, replica_spec)
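        # Every rank contributes the same partial_local here, so the
        # all_reduce sums world_size identical copies: the replicated value
        # should equal partial_local * world_size.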

        self.assertEqual(partial_tensor.size(), partial_local.size())
        self.assertEqual(
            partial_local * self.world_size, global_partial_tensor.to_local()
        )

        # test backward to have replicate grad on partial
        global_partial_tensor.backward(torch.ones_like(global_partial_tensor))
        self.assertIsNotNone(partial_local.grad)
        if device_mesh.get_rank() == 0:
            self.assertEqual(partial_local.grad, torch.ones_like(partial_local))

    @with_comms
    def test_replicate_to_partial(self):
        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
        local_tensor = torch.randn(12, 3, device=self.device_type, requires_grad=True)
        partial_spec = _Partial()
        replica_spec = Replicate()
        # 1) test replicate -> partial forward
        replica_tensor = distribute_tensor(local_tensor, device_mesh, [replica_spec])
        with self.assertRaisesRegex(RuntimeError, "Can not redistribute to _Partial"):
            partial_tensor = replica_tensor.redistribute(device_mesh, [partial_spec])

        from torch.distributed._tensor.redistribute import Redistribute

        partial_tensor = Redistribute.apply(replica_tensor, device_mesh, [partial_spec])
        self.assertEqual(partial_tensor.size(), local_tensor.size())
        # check that it zeroes out the contents on all other ranks
        if self.rank == 0:
            self.assertEqual(replica_tensor.to_local(), partial_tensor.to_local())
        else:
            self.assertEqual(partial_tensor.to_local(), torch.zeros_like(local_tensor))
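        # Keeping the value on one rank and zeroing the rest preserves the
        # implicit sum that _Partial represents: summing the local tensors
        # across ranks still reconstructs the original replicated value.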

        # replicate to partial on sub groups
        local_tensor = torch.randn(12, 3, device=self.device_type)
        device_mesh = DeviceMesh(
            self.device_type,
            torch.arange(self.world_size).reshape(self.world_size // 2, 2),
        )
        # 2) test replicate -> partial on 2d-mesh subgroups
        replica_tensor = distribute_tensor(
            local_tensor, device_mesh, [replica_spec, replica_spec]
        )
        partial_tensor = Redistribute.apply(
            replica_tensor, device_mesh, [partial_spec, partial_spec]
        )
        self.assertEqual(partial_tensor.size(), local_tensor.size())

        if self.rank != 3:
            # replicate to partial should only zero out rank 3, leaving
            # ranks 0/2 (rank 0 on mesh dim 1) and ranks 0/1 (rank 0 on
            # mesh dim 0) untouched
            self.assertEqual(replica_tensor.to_local(), partial_tensor.to_local())
        else:
            self.assertEqual(replica_tensor.to_local(), torch.zeros_like(local_tensor))

    @with_comms
    def test_partial_to_shard(self):
        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
        partial_spec = [_Partial()]
        my_rank = device_mesh.get_rank()

        input_sizes_and_shard_dim = [
            ((self.world_size * 3, 3), 0),
            ((self.world_size * 3 + 1, 3), 0),
            ((self.world_size * 3 + 2, 3), 0),
            ((3, self.world_size * 3), 1),
            ((3, self.world_size * 3 + 1), 1),
            ((3, self.world_size * 3 + 2), 1),
        ]

        for input_size, shard_dim in input_sizes_and_shard_dim:
            shard_spec = [Shard(shard_dim)]

            partial_local = torch.ones(input_size, device=self.device_type)
            partial_tensor = DTensor.from_local(
                partial_local, device_mesh, partial_spec, run_check=False
            )

            full_chunk_size = (
                input_size[shard_dim] + self.world_size - 1
            ) // self.world_size
            chunk_sizes = [
                max(
                    min(input_size[shard_dim], full_chunk_size * (idx + 1))
                    - full_chunk_size * idx,
                    0,
                )
                for idx in range(self.world_size)
            ]
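            # Worked example (assuming world_size == 4): for a sharded dim of
            # size 14, full_chunk_size == (14 + 3) // 4 == 4 and
            # chunk_sizes == [4, 4, 4, 2]; for size 13 it is [4, 4, 4, 1],
            # matching how torch.chunk splits the tensor.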

            local_shape = list(input_size)
            local_shape[shard_dim] = chunk_sizes[my_rank]

            # test partial -> shard, which triggers a reduce_scatter
            scatter_shard_tensor = partial_tensor.redistribute(device_mesh, shard_spec)
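            # The reduce_scatter sums world_size copies of ones and hands each
            # rank its own shard, so every local element should equal
            # world_size.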
            self.assertEqual(scatter_shard_tensor.size(), partial_tensor.size())
            self.assertEqual(scatter_shard_tensor.placements, shard_spec)
            self.assertEqual(
                scatter_shard_tensor.to_local(),
                torch.ones(local_shape) * self.world_size,
            )


class MultiDimRedistributeTest(DTensorTestBase):
    @property
    def world_size(self) -> int:
        return 8

    @with_comms
    def test_multi_dim_mesh(self):
        devices = torch.arange(self.world_size)
        for mesh_shape in [devices, devices.view(4, 2), devices.view(2, 2, 2)]:
            device_mesh = DeviceMesh(self.device_type, mesh_shape)
            tensor_shape = (16, 24)

            if torch.distributed.get_rank() == 0:
                full_tensor = torch.randn(*tensor_shape)
            else:
                # these should be entirely ignored, because distribute_tensor
                # is expected to override the shards on ranks != 0
                full_tensor = torch.ones(*tensor_shape)

            possibilities = [Replicate()] + [Shard(i) for i in range(full_tensor.ndim)]
            all_outputs = list(itertools.product(*(mesh_shape.ndim * [possibilities])))
            all_inputs = list(
                itertools.product(*(mesh_shape.ndim * [possibilities + [_Partial()]]))
            )

            for inputs in all_inputs:
                # if an input placement is partial, temporarily make it
                # Replicate, then reinterpret it as partial afterwards
                repl_inputs = [Replicate() if s.is_partial() else s for s in inputs]
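                # e.g. (for a 2D mesh) inputs == (_Partial(), Shard(0))
                # becomes repl_inputs == (Replicate(), Shard(0))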
                dt = distribute_tensor(full_tensor, device_mesh, repl_inputs)

                if repl_inputs != inputs:
                    # create a new DTensor reinterpreting some of the
                    # replicated entries as "Partial"
                    dt = DTensor.from_local(
                        dt.to_local(), device_mesh, inputs, run_check=False
                    )

                for outputs in all_outputs:
                    # redistribute on target outputs
                    dt2 = dt.redistribute(device_mesh, outputs)

                    # replicate and then get first shard
                    local_full = dt2.redistribute(
                        device_mesh, device_mesh.ndim * [Replicate()]
                    ).to_local()

                    if torch.distributed.get_rank() == 0:
                        self.assertEqual(local_full.shape, full_tensor.shape)
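
                        # Reinterpreting replicated values as _Partial means
                        # every rank along that mesh dim contributes an
                        # identical copy, and replicating sums them; so the
                        # expected result is full_tensor scaled by the product
                        # of the partial mesh dims' sizes.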
                        num_sums = 1
                        for idx, input in enumerate(inputs):
                            if input.is_partial():
                                num_sums *= mesh_shape.size(idx)
                        expected = num_sums * full_tensor
                        self.assertEqual(local_full, expected)


if __name__ == "__main__":
    run_tests()