# Owner(s): ["oncall: distributed"]

import sys

import torch
from torch.distributed.tensor import (
    distribute_module,
    distribute_tensor,
    DTensor,
    Replicate,
    Shard,
)
from torch.distributed.tensor.debug import CommDebugMode
from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN
from torch.testing._internal.distributed._tensor.common_dtensor import (
    DTensorTestBase,
    with_comms,
)
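
# This file exercises DTensor sharding of torch.nn.Embedding: the weight is
# sharded column-wise (Shard(1), along embedding_dim) or row-wise (Shard(0),
# along num_embeddings) and compared against a plain replicated module.
# A minimal sketch of the pattern under test (illustrative only; `mesh` is
# assumed to be an already-initialized 1-D DeviceMesh and `idx` an index tensor):
#
#   weight = torch.nn.Embedding(16, 8).weight
#   dweight = distribute_tensor(weight, mesh, [Shard(1)])  # column-wise shard
#   didx = DTensor.from_local(idx, mesh, [Replicate()], run_check=False)
#   out = torch.nn.functional.embedding(didx, dweight).full_tensor()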


if TEST_WITH_DEV_DBG_ASAN:
    print(
        "Skip dev-asan as torch + multiprocessing spawn have known issues",
        file=sys.stderr,
    )
    sys.exit(0)


funcol = torch.ops.c10d_functional


class TestEmbeddingOp(DTensorTestBase):
    def _apply_sharding(self, embedding_mod, shard_dim, device_mesh):
        def shard_embedding_fn(name, module, device_mesh):
            for name, param in module.named_parameters():
                dist_param = torch.nn.Parameter(
                    distribute_tensor(param, device_mesh, [Shard(shard_dim)])
                )
                module.register_parameter(name, dist_param)

        sharded_embedding = distribute_module(
            embedding_mod, device_mesh, shard_embedding_fn
        )
        return sharded_embedding
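
    # `_run_embedding_op_test` builds two identical nn.Embedding modules, shards
    # one of them along `shard_dim`, and checks that the forward output, backward,
    # and weight gradient match the plain local module, with the sharded forward
    # and backward themselves issuing no collectives.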
    def _run_embedding_op_test(
        self,
        device_mesh,
        shard_dim,
        input_size,
        num_embeddings,
        embedding_dim,
        **kwargs,
    ):
        # Use the same seed so both embeddings start from identical weights.
        torch.manual_seed(0)
        local_embedding = torch.nn.Embedding(
            num_embeddings,
            embedding_dim,
            device=self.device_type,
            **kwargs,
        )
        sharded_embedding = torch.nn.Embedding(
            num_embeddings,
            embedding_dim,
            device=self.device_type,
            **kwargs,
        )

        # Copy the local embedding's weight into the embedding that will be sharded.
        sharded_embedding.weight = torch.nn.Parameter(
            local_embedding.weight.detach().clone()
        )

        sharded_embedding = self._apply_sharding(
            sharded_embedding, shard_dim, device_mesh
        )

        # Run sharded computation
        torch.manual_seed(10)
        inp = torch.randint(
            0, num_embeddings, tuple(input_size), device=self.device_type
        )
        target = torch.empty(
            *inp.size(), embedding_dim, dtype=torch.float, device=self.device_type
        ).random_(0, 1)
        dist_inp = distribute_tensor(inp, device_mesh, [Replicate()])

        # fwd computation, ensure no comm happened
        with CommDebugMode() as fwd_mode:
            dist_output = sharded_embedding(dist_inp)
            self.assertEqual(fwd_mode.get_total_counts(), 0)

        output = dist_output.full_tensor()
        # Run local computation
        local_output = local_embedding(inp)

        # Verify
        self.assertEqual(local_output, output)

        # Use a sample cross entropy loss to verify backward and grad computation.
        loss = torch.nn.CrossEntropyLoss()
        emb_loss = loss(
            output,
            target,
        )
        emb_dup_loss = loss(
            local_output,
            target,
        )

        # local embedding backward
        emb_dup_loss.backward()

        # sharded embedding bwd computation, ensure no comm happened
        with CommDebugMode() as bwd_mode:
            emb_loss.backward()
            self.assertEqual(bwd_mode.get_total_counts(), 0)

        gradient = sharded_embedding.weight.grad.full_tensor()

        local_grad = local_embedding.weight.grad

        # Verify gradient.
        self.assertEqual(gradient, local_grad)

        # Validate the torch.nn.functional.embedding version as well.
        local_output = torch.nn.functional.embedding(
            inp,
            local_embedding.weight,
            **kwargs,
        )
        sharded_output = torch.nn.functional.embedding(
            DTensor.from_local(inp, device_mesh, [Replicate()], run_check=False),
            sharded_embedding.weight,
            **kwargs,
        )
        self.assertEqual(local_output, sharded_output.full_tensor())
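
    # Column-wise sharding: Shard(1) splits the weight along embedding_dim, so
    # each rank holds a slice of every embedding vector and the lookup on a
    # replicated index tensor needs no communication.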
    @with_comms
    def test_sharded_embedding_colwise(self):
        mesh = self.build_device_mesh()
        self._run_embedding_op_test(mesh, 1, [5, 4], 17, 12)
        self._run_embedding_op_test(mesh, 1, [6, 7, 6], 21, 11)
        self._run_embedding_op_test(mesh, 1, [8, 6, 5, 4], 23, 13)
        self._run_embedding_op_test(mesh, 1, [8, 6, 5, 4, 7], 23, 16)
        self._run_embedding_op_test(mesh, 1, [4], 15, 14)
        self._run_embedding_op_test(mesh, 1, [34], 15, 14, padding_idx=10)
        self._run_embedding_op_test(mesh, 1, [8, 6, 5, 4], 23, 13, padding_idx=12)
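
    # max_norm requires renorming the weight in place via aten.embedding_renorm_,
    # which has no sharding strategy registered, so the sharded path is expected
    # to raise NotImplementedError.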
    @with_comms
    def test_sharded_embedding_colwise_max_norm_errors(self):
        mesh = self.build_device_mesh()
        with self.assertRaisesRegex(
            NotImplementedError,
            "aten.embedding_renorm_.default does not have a sharding strategy registered.",
        ):
            self._run_embedding_op_test(
                mesh, 1, [8, 6, 5, 4], 23, 13, padding_idx=12, max_norm=2.0
            )
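
    # Row-wise sharding: Shard(0) splits the weight along num_embeddings, so each
    # rank only holds a subset of rows. The forward yields a _MaskPartial output,
    # and materializing it with full_tensor() takes a single all_reduce.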
    @with_comms
    def test_sharded_embedding_rowwise(self):
        mesh = self.build_device_mesh()
        # test correctness
        self._run_embedding_op_test(mesh, 0, [5, 12], 16, 22)
        self._run_embedding_op_test(mesh, 0, [6, 7, 6], 13, 22)
        self._run_embedding_op_test(mesh, 0, [34], 15, 14, padding_idx=10)

        from torch.distributed.tensor._ops._embedding_ops import _MaskPartial

        # test collectives
        embedding_mod = torch.nn.Embedding(10, 20, device=self.device_type)
        sharded_embedding = self._apply_sharding(embedding_mod, 0, mesh)
        inp = torch.randint(0, 10, (8, 8), device=self.device_type)
        replicated_inp = DTensor.from_local(inp, mesh, [Replicate()], run_check=False)
        output = sharded_embedding(replicated_inp)
        self.assertIsInstance(output.placements[0], _MaskPartial)

        comm_mode = CommDebugMode()

        with comm_mode:
            output.full_tensor()
            self.assertEqual(comm_mode.get_total_counts(), 1)
            self.assertEqual(comm_mode.get_comm_counts()[funcol.all_reduce], 1)
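
    # Running several row-wise sharded embeddings against the same replicated
    # input checks whether their _MaskPartial placements are shared through the
    # sharding-propagation cache or kept distinct.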
    @with_comms
    def test_multiple_embeddings_rowwise(self):
        mesh = self.build_device_mesh()

        inp = torch.randint(0, 10, (4, 4), device=self.device_type)
        replicated_inp = DTensor.from_local(inp, mesh, [Replicate()], run_check=False)

        from torch.distributed.tensor._ops._embedding_ops import _MaskPartial

        # case 1: two embeddings with the same shape, thus sharing the underlying
        # _MaskPartial and MaskBuffer, because of a cache hit from sharding propagation

        emb1 = torch.nn.Embedding(10, 23, device=self.device_type)
        sharded_emb1 = self._apply_sharding(emb1, 0, mesh)
        output1 = sharded_emb1(replicated_inp)

        emb2 = torch.nn.Embedding(10, 29, device=self.device_type)
        sharded_emb2 = self._apply_sharding(emb2, 0, mesh)
        output2 = sharded_emb2(replicated_inp)

        partial_placement1 = output1.placements[0]
        self.assertIsInstance(partial_placement1, _MaskPartial)
        output1.full_tensor()

        partial_placement2 = output2.placements[0]
        self.assertIsInstance(partial_placement2, _MaskPartial)
        output2.full_tensor()

        self.assertTrue(id(partial_placement1), id(partial_placement2))

        # case 2: two embeddings with the same logical_dim_size, but different logical_shape,
        # thus they will have different _MaskPartial placements (with no cache hit)

        emb3 = torch.nn.Embedding(10, 29, device=self.device_type)
        sharded_emb3 = self._apply_sharding(emb3, 0, mesh)
        output3 = sharded_emb3(replicated_inp)
        partial_placement3 = output3.placements[0]
        self.assertIsInstance(partial_placement3, _MaskPartial)
        output3.full_tensor()

        # not equal because of different logical_shape, despite the same logical_dim_size
        self.assertNotEqual(partial_placement1, partial_placement3)


if __name__ == "__main__":
    run_tests()