[DTensor][BE] improve DTensor ops correctness check utils (#158112)
**Summary**
Implemented the test pattern described in https://github.com/pytorch/pytorch/pull/157991#discussion_r2196363170 as a utility method on `DTensorTestBase`. The differences from `DTensorTestBase._test_op` are:
1. it lets users specify the `Partial` placement explicitly;
2. it supports tree-like (pytree) output structures.

**Test**
So far only `DistTensorOpsTest.test_split_on_partial` adopts `DTensorTestBase._test_op_on_dtensor`:
`pytest test/distributed/tensor/test_tensor_ops.py -s -k test_split_on_partial`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158112
Approved by: https://github.com/Skylion007, https://github.com/zpcore
ghstack dependencies: #158051
This commit is contained in: parent 4c1fabf2c9, commit add0b450bd
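A rough usage sketch of the new helper, adapted from the `test_split_on_partial` change below. The test-class name, the concrete split arguments, and the `with_comms`/import details are illustrative assumptions, not part of this PR:

import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor, Partial
from torch.testing._internal.distributed._tensor.common_dtensor import (
    DTensorTestBase,
    with_comms,
)

class MyDTensorOpsTest(DTensorTestBase):  # hypothetical test class
    @with_comms
    def test_split_on_partial(self):
        mesh = init_device_mesh(self.device_type, (self.world_size,))
        # each rank holds an unreduced ("sum") shard of the logical 8x8 tensor
        partial_dt = DTensor.from_local(
            local_tensor=torch.randn(8, 8, device=self.device_type),
            device_mesh=mesh,
            placements=[Partial(reduce_op="sum")],
        )
        # compares torch.split(partial_dt), leaf by leaf via full_tensor(),
        # against torch.split(partial_dt.full_tensor())
        self._test_op_on_dtensor(torch.split, partial_dt, 2, dim=0)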
test/distributed/tensor/test_tensor_ops.py

@@ -2,7 +2,6 @@
# Owner(s): ["oncall: distributed"]

import torch
import torch.distributed._functional_collectives as funcol
from torch.distributed.tensor import (
    DeviceMesh,
    distribute_tensor,

@@ -725,24 +724,17 @@ class DistTensorOpsTest(DTensorTestBase):
        mesh = init_device_mesh(self.device_type, (self.world_size,))

        partial_tensor = torch.randn(8, 8, device=self.device_type)
        replicate_tensor = partial_tensor.detach().clone()
        replicate_tensor = funcol.all_reduce(
            replicate_tensor, reduce_op, mesh
        )  # all reduce to full tensor
        replicate_tensor_list = replicate_tensor.split(split_size, dim=split_dim)

        partial_dt = DTensor.from_local(
            local_tensor=partial_tensor,
            device_mesh=mesh,
            placements=[Partial(reduce_op=reduce_op)],
        )
        partial_dt_list = partial_dt.split(split_size, dim=split_dim)

        replicate_dt_full_tensor_list = [dt.full_tensor() for dt in partial_dt_list]
        for replicate_tensor, replicate_dt_full_tensor in zip(
            replicate_tensor_list, replicate_dt_full_tensor_list
        ):
            self.assertEqual(replicate_tensor, replicate_dt_full_tensor)
        self._test_op_on_dtensor(
            torch.split,
            partial_dt,
            split_size,
            dim=split_dim,
        )


if __name__ == "__main__":
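The property this test leans on is that an elementwise reduction commutes with `split`: summing the per-rank partial shards and then splitting gives the same pieces as splitting each shard and summing afterwards. A minimal single-process sketch of that equivalence, with plain tensors standing in for the per-rank partial shards (no distributed setup, nothing from this PR's API):

import torch

# pretend these are the per-rank partial ("sum") shards of one logical 8x8 tensor
partial_shards = [torch.randn(8, 8) for _ in range(4)]

# reduce first, then split (what full_tensor() followed by torch.split computes)
full = torch.stack(partial_shards).sum(dim=0)
reduce_then_split = full.split(2, dim=0)

# split first, then reduce each piece (what splitting a Partial DTensor amounts to)
split_then_reduce = [
    torch.stack([shard.split(2, dim=0)[i] for shard in partial_shards]).sum(dim=0)
    for i in range(len(reduce_then_split))
]

for a, b in zip(reduce_then_split, split_then_reduce):
    torch.testing.assert_close(a, b)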
torch/testing/_internal/distributed/_tensor/common_dtensor.py

@@ -17,6 +17,7 @@ from torch._utils import _get_device_module
from torch.distributed.tensor import (
    DeviceMesh,
    distribute_tensor,
    DTensor,
    Placement,
    Replicate,
    Shard,

@@ -403,6 +404,32 @@ class DTensorTestBase(MultiProcessTestCase):
        super().setUp()
        self._spawn_processes()

    def _test_op_on_dtensor(self, op_call, *args, **kwargs) -> None:
        """
        This function checks ``op_call(dtensor).full_tensor() == op_call(dtensor.full_tensor())``.
        Unlike ``_test_op``, where the DTensor sharding is generated by DTensorConverter,
        this function takes DTensor objects directly as arguments and tests the equality
        of calling the op on ``full_tensor()`` and on the DTensor.
        """
        # call full_tensor() on DTensor args/kwargs
        args_flattened, args_spec = tree_flatten(args)
        full_tensor_args_flattened = tuple(
            arg.full_tensor().detach().clone() if isinstance(arg, DTensor) else arg
            for arg in args_flattened
        )
        full_tensor_args = tree_unflatten(full_tensor_args_flattened, args_spec)
        full_tensor_kwargs = {
            k: v.full_tensor() if isinstance(v, DTensor) else v
            for k, v in kwargs.items()
        }

        out_flattened, _ = tree_flatten(
            op_call(*full_tensor_args, **full_tensor_kwargs)
        )
        d_out_flattened, _ = tree_flatten(op_call(*args, **kwargs))
        d_out_full_tensor_flattened = [dt.full_tensor() for dt in d_out_flattened]
        self.assertEqual(out_flattened, d_out_full_tensor_flattened)

    # pyre-ignore[2]:
    def _test_op(self, mesh: DeviceMesh, op_call, *args, **kwargs) -> None:
        out = op_call(*args, **kwargs)
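The "tree-like output structure" support comes from flattening both outputs with pytree before comparing, so ops whose result is a tuple, list, or dict of tensors (e.g. `torch.split`) are compared leaf by leaf. A minimal sketch of that flatten/unflatten round trip, assuming `tree_flatten`/`tree_unflatten` come from `torch.utils._pytree` as they commonly do in this test file:

import torch
from torch.utils._pytree import tree_flatten, tree_unflatten

# torch.split returns a tuple of tensors -- a small pytree
out = torch.split(torch.arange(8.0), 2)

# flatten into a flat list of tensor leaves plus a spec describing the structure
leaves, spec = tree_flatten(out)
assert len(leaves) == 4 and all(isinstance(t, torch.Tensor) for t in leaves)

# leaves can be transformed one by one (e.g. DTensor -> full_tensor()) and the
# original container structure rebuilt afterwards
doubled = tree_unflatten([t * 2 for t in leaves], spec)
assert isinstance(doubled, tuple) and torch.equal(doubled[0], torch.tensor([0.0, 2.0]))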