[DTensor][BE] improve DTensor ops correctness check utils (#158112)

**Summary**
Implemented the test pattern described in https://github.com/pytorch/pytorch/pull/157991#discussion_r2196363170 as a util method in `DTensorTestBase`. The differences from `DTensorTestBase._test_op` are:
1. it allows users to specify the `Partial` placement (see the usage sketch below).
2. it supports tree-like output structures.
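For illustration, a minimal sketch of how the new util is meant to be called, assuming it runs inside a `DTensorTestBase` subclass with an already-initialized `mesh` (the split size `2`, `dim=0`, and `reduce_op="sum"` here are made-up example values, not from the PR):

```python
# minimal sketch, assuming a DTensorTestBase subclass where
# mesh = init_device_mesh(self.device_type, (self.world_size,))
partial_dt = DTensor.from_local(
    local_tensor=torch.randn(8, 8, device=self.device_type),
    device_mesh=mesh,
    placements=[Partial(reduce_op="sum")],  # Partial placement can now be specified
)
# checks op(dt).full_tensor() == op(dt.full_tensor()) leaf by leaf;
# torch.split returns a tuple of tensors, i.e. a tree-like output structure
self._test_op_on_dtensor(torch.split, partial_dt, 2, dim=0)
```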

**Test**
So far, `DTensorTestBase._test_op_on_dtensor` is only adopted in `DistTensorOpsTest.test_split_on_partial`:
`pytest test/distributed/tensor/test_tensor_ops.py -s -k test_split_on_partial`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158112
Approved by: https://github.com/Skylion007, https://github.com/zpcore
ghstack dependencies: #158051
commit add0b450bd (parent 4c1fabf2c9)
Authored by Xilun Wu on 2025-07-14 15:58:20 -07:00, committed by PyTorch MergeBot
2 changed files with 33 additions and 14 deletions

test/distributed/tensor/test_tensor_ops.py

@@ -2,7 +2,6 @@
 # Owner(s): ["oncall: distributed"]
 import torch
-import torch.distributed._functional_collectives as funcol
 from torch.distributed.tensor import (
     DeviceMesh,
     distribute_tensor,
@@ -725,24 +724,17 @@ class DistTensorOpsTest(DTensorTestBase):
         mesh = init_device_mesh(self.device_type, (self.world_size,))
         partial_tensor = torch.randn(8, 8, device=self.device_type)
-        replicate_tensor = partial_tensor.detach().clone()
-        replicate_tensor = funcol.all_reduce(
-            replicate_tensor, reduce_op, mesh
-        )  # all reduce to full tensor
-        replicate_tensor_list = replicate_tensor.split(split_size, dim=split_dim)
         partial_dt = DTensor.from_local(
             local_tensor=partial_tensor,
             device_mesh=mesh,
             placements=[Partial(reduce_op=reduce_op)],
         )
-        partial_dt_list = partial_dt.split(split_size, dim=split_dim)
-        replicate_dt_full_tensor_list = [dt.full_tensor() for dt in partial_dt_list]
-        for replicate_tensor, replicate_dt_full_tensor in zip(
-            replicate_tensor_list, replicate_dt_full_tensor_list
-        ):
-            self.assertEqual(replicate_tensor, replicate_dt_full_tensor)
+        self._test_op_on_dtensor(
+            torch.split,
+            partial_dt,
+            split_size,
+            dim=split_dim,
+        )


 if __name__ == "__main__":

torch/testing/_internal/distributed/_tensor/common_dtensor.py

@@ -17,6 +17,7 @@ from torch._utils import _get_device_module
 from torch.distributed.tensor import (
     DeviceMesh,
     distribute_tensor,
+    DTensor,
     Placement,
     Replicate,
     Shard,
@@ -403,6 +404,32 @@ class DTensorTestBase(MultiProcessTestCase):
         super().setUp()
         self._spawn_processes()

+    def _test_op_on_dtensor(self, op_call, *args, **kwargs) -> None:
+        """
+        This function checks ``op_call(dtensor).full_tensor() == op_call(dtensor.full_tensor())``.
+        Unlike ``_test_op``, where the DTensor sharding is generated by ``DTensorConverter``,
+        this function takes DTensor objects directly as arguments and tests that calling
+        the op on ``full_tensor()`` and on the DTensor produce equal results.
+        """
+        # call full_tensor() on DTensor args/kwargs
+        args_flattened, args_spec = tree_flatten(args)
+        full_tensor_args_flattened = tuple(
+            arg.full_tensor().detach().clone() if isinstance(arg, DTensor) else arg
+            for arg in args_flattened
+        )
+        full_tensor_args = tree_unflatten(full_tensor_args_flattened, args_spec)
+        full_tensor_kwargs = {
+            k: v.full_tensor() if isinstance(v, DTensor) else v
+            for k, v in kwargs.items()
+        }
+        out_flattened, _ = tree_flatten(
+            op_call(*full_tensor_args, **full_tensor_kwargs)
+        )
+        d_out_flattened, _ = tree_flatten(op_call(*args, **kwargs))
+        d_out_full_tensor_flattened = [dt.full_tensor() for dt in d_out_flattened]
+        self.assertEqual(out_flattened, d_out_full_tensor_flattened)
+
     # pyre-ignore[2]:
     def _test_op(self, mesh: DeviceMesh, op_call, *args, **kwargs) -> None:
         out = op_call(*args, **kwargs)
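
As a side note on the tree handling above: the helper flattens arbitrarily nested op outputs into a list of leaves with `tree_flatten` and rebuilds the structure with `tree_unflatten`, which is what makes tree-like outputs such as `torch.split`'s tuple work. A minimal sketch of that round-trip, assuming the pytree utilities come from `torch.utils._pytree` (the nested values here are made up for illustration):

```python
from torch.utils._pytree import tree_flatten, tree_unflatten

# tree_flatten turns a nested structure into a flat list of leaves
# plus a spec describing the original structure
out = (1, {"a": 2, "b": [3, 4]})
leaves, spec = tree_flatten(out)
print(leaves)  # [1, 2, 3, 4]

# tree_unflatten rebuilds the original nested structure from leaves + spec
assert tree_unflatten(leaves, spec) == out
```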